Restructure: Move 52 files into 7 domain packages
CI / go-lint (push) Has been skipped
CI / python-lint (push) Has been skipped
CI / nodejs-lint (push) Has been skipped
CI / test-go-school (push) Successful in 28s
CI / test-go-edu-search (push) Successful in 28s
CI / test-python-klausur (push) Failing after 2m22s
CI / test-python-agent-core (push) Successful in 21s
CI / test-nodejs-website (push) Successful in 23s

korrektur/ zeugnis/ admin/ compliance/ worksheet/ training/ metrics/
52 backward-compatible shims in backend/, package-relative imports, RAG logic untouched.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-04-25 22:10:48 +02:00
parent 0504d22b8e
commit 165c493d1e
111 changed files with 11859 additions and 11609 deletions
@@ -0,0 +1,6 @@
"""
admin package — admin APIs for NiBiS, RAG, templates.
Backward-compatible re-exports: consumers can still use
``from admin_api import ...`` etc. via the shim files in backend/.
"""
+33
View File
@@ -0,0 +1,33 @@
"""
Admin API for NiBiS Data Management (barrel re-export)
This module was split into:
- admin_nibis.py (NiBiS ingestion, search, stats)
- admin_rag.py (RAG upload, metrics, storage)
- admin_templates.py (Legal templates ingestion, search)
The `router` object is assembled here by including all sub-routers.
Importers that did `from admin_api import router` continue to work.
"""
from fastapi import APIRouter
from .nibis import router as _nibis_router
from .rag import router as _rag_router
from .templates import router as _templates_router
# Re-export internal state for test importers
from .nibis import ( # noqa: F401
_ingestion_status,
NiBiSSearchRequest,
search_nibis,
)
from .rag import _upload_history # noqa: F401
from .templates import _templates_ingestion_status # noqa: F401
# Assemble the combined router.
# All sub-routers use prefix="/api/v1/admin", so include without extra prefix.
router = APIRouter()
router.include_router(_nibis_router)
router.include_router(_rag_router)
router.include_router(_templates_router)
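Not part of this commit, but for illustration: because the backend/admin_api.py shim (further down in this diff) forwards to this package module, the old and new import paths resolve to the same router object. A minimal sketch of mounting it in a FastAPI app, assuming a typical entrypoint:

from fastapi import FastAPI

# Old import path, still valid through the backend/admin_api.py shim
from admin_api import router as legacy_admin_router
# New canonical path after the restructure
from admin.api import router as admin_router

app = FastAPI()
app.include_router(admin_router)  # sub-routers already carry the /api/v1/admin prefix

# Both names refer to the same APIRouter instance
assert admin_router is legacy_admin_router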
+316
View File
@@ -0,0 +1,316 @@
"""
Admin API - NiBiS Ingestion & Search
Endpoints for NiBiS data discovery, ingestion, search, and statistics.
Extracted from admin_api.py for file-size compliance.
"""
from fastapi import APIRouter, HTTPException, BackgroundTasks, Query
from pydantic import BaseModel
from typing import Optional, List, Dict
from datetime import datetime
from nibis_ingestion import (
run_ingestion,
discover_documents,
extract_zip_files,
DOCS_BASE_PATH,
)
from qdrant_service import QdrantService, search_nibis_eh, get_qdrant_client
from eh_pipeline import generate_single_embedding
router = APIRouter(prefix="/api/v1/admin", tags=["Admin"])
# Store for background task status
_ingestion_status: Dict = {
"running": False,
"last_run": None,
"last_result": None,
}
# =============================================================================
# Models
# =============================================================================
class IngestionRequest(BaseModel):
ewh_only: bool = True
year_filter: Optional[int] = None
subject_filter: Optional[str] = None
class IngestionStatus(BaseModel):
running: bool
last_run: Optional[str]
documents_indexed: Optional[int]
chunks_created: Optional[int]
errors: Optional[List[str]]
class NiBiSSearchRequest(BaseModel):
query: str
year: Optional[int] = None
subject: Optional[str] = None
niveau: Optional[str] = None
limit: int = 5
class NiBiSSearchResult(BaseModel):
id: str
score: float
text: str
year: Optional[int]
subject: Optional[str]
niveau: Optional[str]
task_number: Optional[int]
class DataSourceStats(BaseModel):
source_dir: str
year: int
document_count: int
subjects: List[str]
# =============================================================================
# Endpoints
# =============================================================================
@router.get("/nibis/status", response_model=IngestionStatus)
async def get_ingestion_status():
"""Get status of NiBiS ingestion pipeline."""
last_result = _ingestion_status.get("last_result") or {}
return IngestionStatus(
running=_ingestion_status["running"],
last_run=_ingestion_status.get("last_run"),
documents_indexed=last_result.get("documents_indexed"),
chunks_created=last_result.get("chunks_created"),
errors=(last_result.get("errors") or [])[:10],
)
@router.post("/nibis/extract-zips")
async def extract_zip_files_endpoint():
"""Extract all ZIP files in za-download directories."""
try:
extracted = extract_zip_files(DOCS_BASE_PATH)
return {
"status": "success",
"extracted_count": len(extracted),
"directories": [str(d) for d in extracted],
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/nibis/discover")
async def discover_nibis_documents(
ewh_only: bool = Query(True, description="Only return Erwartungshorizonte"),
year: Optional[int] = Query(None, description="Filter by year"),
subject: Optional[str] = Query(None, description="Filter by subject"),
):
"""
Discover available NiBiS documents without indexing.
Useful for previewing what will be indexed.
"""
try:
documents = discover_documents(DOCS_BASE_PATH, ewh_only=ewh_only)
# Apply filters
if year:
documents = [d for d in documents if d.year == year]
if subject:
documents = [d for d in documents if subject.lower() in d.subject.lower()]
# Group by year and subject
by_year: Dict[int, int] = {}
by_subject: Dict[str, int] = {}
for doc in documents:
by_year[doc.year] = by_year.get(doc.year, 0) + 1
by_subject[doc.subject] = by_subject.get(doc.subject, 0) + 1
return {
"total_documents": len(documents),
"by_year": dict(sorted(by_year.items())),
"by_subject": dict(sorted(by_subject.items(), key=lambda x: -x[1])),
"sample_documents": [
{
"id": d.id,
"filename": d.raw_filename,
"year": d.year,
"subject": d.subject,
"niveau": d.niveau,
"doc_type": d.doc_type,
}
for d in documents[:20]
],
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/nibis/ingest")
async def start_ingestion(
request: IngestionRequest,
background_tasks: BackgroundTasks,
):
"""
Start NiBiS data ingestion in background.
"""
if _ingestion_status["running"]:
raise HTTPException(
status_code=409,
detail="Ingestion already running. Check /nibis/status for progress."
)
async def run_ingestion_task():
global _ingestion_status
_ingestion_status["running"] = True
_ingestion_status["last_run"] = datetime.now().isoformat()
try:
result = await run_ingestion(
ewh_only=request.ewh_only,
dry_run=False,
year_filter=request.year_filter,
subject_filter=request.subject_filter,
)
_ingestion_status["last_result"] = result
except Exception as e:
_ingestion_status["last_result"] = {"error": str(e), "errors": [str(e)]}
finally:
_ingestion_status["running"] = False
background_tasks.add_task(run_ingestion_task)
return {
"status": "started",
"message": "Ingestion started in background. Check /nibis/status for progress.",
"filters": {
"ewh_only": request.ewh_only,
"year": request.year_filter,
"subject": request.subject_filter,
},
}
@router.post("/nibis/search", response_model=List[NiBiSSearchResult])
async def search_nibis(request: NiBiSSearchRequest):
"""
Semantic search in NiBiS Erwartungshorizonte.
"""
try:
query_embedding = await generate_single_embedding(request.query)
if not query_embedding:
raise HTTPException(status_code=500, detail="Failed to generate embedding")
results = await search_nibis_eh(
query_embedding=query_embedding,
year=request.year,
subject=request.subject,
niveau=request.niveau,
limit=request.limit,
)
return [
NiBiSSearchResult(
id=r["id"],
score=r["score"],
text=r.get("text", "")[:500],
year=r.get("year"),
subject=r.get("subject"),
niveau=r.get("niveau"),
task_number=r.get("task_number"),
)
for r in results
]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/nibis/collections")
async def get_collections_info():
"""Get information about all Qdrant collections."""
try:
client = get_qdrant_client()
collections = client.get_collections().collections
result = []
for c in collections:
try:
info = client.get_collection(c.name)
result.append({
"name": c.name,
"vectors_count": info.vectors_count,
"points_count": info.points_count,
"status": info.status.value,
})
except Exception as e:
result.append({
"name": c.name,
"error": str(e),
})
return {"collections": result}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/nibis/stats")
async def get_nibis_stats():
"""Get detailed statistics about indexed NiBiS data."""
try:
qdrant = QdrantService()
stats = await qdrant.get_stats("bp_nibis_eh")
if "error" in stats:
return {
"indexed": False,
"message": "NiBiS collection not yet created. Run ingestion first.",
}
client = get_qdrant_client()
scroll_result = client.scroll(
collection_name="bp_nibis_eh",
limit=1000,
with_payload=True,
with_vectors=False,
)
years = set()
subjects = set()
niveaus = set()
for point in scroll_result[0]:
if point.payload:
if "year" in point.payload:
years.add(point.payload["year"])
if "subject" in point.payload:
subjects.add(point.payload["subject"])
if "niveau" in point.payload:
niveaus.add(point.payload["niveau"])
return {
"indexed": True,
"total_chunks": stats.get("points_count", 0),
"years": sorted(list(years)),
"subjects": sorted(list(subjects)),
"niveaus": sorted(list(niveaus)),
}
except Exception as e:
return {
"indexed": False,
"error": str(e),
}
@router.delete("/nibis/collection")
async def delete_nibis_collection():
"""Delete the entire NiBiS collection. WARNING: removes all indexed data!"""
try:
client = get_qdrant_client()
client.delete_collection("bp_nibis_eh")
return {"status": "deleted", "collection": "bp_nibis_eh"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
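The ingestion endpoints are meant to be polled: POST /nibis/ingest schedules the background task and GET /nibis/status reads the module-level _ingestion_status dict. A rough client sketch (not part of this commit) using httpx against a hypothetical localhost deployment:

import time
import httpx

BASE = "http://localhost:8000/api/v1/admin"  # hypothetical host and port

# Start ingestion of 2024 Erwartungshorizonte only
resp = httpx.post(f"{BASE}/nibis/ingest", json={"ewh_only": True, "year_filter": 2024})
resp.raise_for_status()

# Poll until the background task has finished
status = httpx.get(f"{BASE}/nibis/status").json()
while status["running"]:
    time.sleep(5)
    status = httpx.get(f"{BASE}/nibis/status").json()

print(status["documents_indexed"], status["chunks_created"])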
+281
View File
@@ -0,0 +1,281 @@
"""
Admin API - RAG Upload & Metrics
Endpoints for uploading documents, tracking uploads, RAG metrics,
search feedback, storage stats, and service initialization.
Extracted from admin_api.py for file-size compliance.
"""
from fastapi import APIRouter, HTTPException, BackgroundTasks, Query, UploadFile, File, Form
from pydantic import BaseModel
from typing import Optional, List, Dict
from datetime import datetime
from pathlib import Path
import zipfile
import tempfile
import os
from nibis_ingestion import run_ingestion, DOCS_BASE_PATH
# Import ingestion status from nibis module for auto-ingest
from .nibis import _ingestion_status
# Optional: MinIO and PostgreSQL integrations
try:
from minio_storage import upload_rag_document, get_storage_stats, init_minio_bucket
MINIO_AVAILABLE = True
except ImportError:
MINIO_AVAILABLE = False
try:
from metrics_db import (
init_metrics_tables, store_feedback, log_search, log_upload,
calculate_metrics, get_recent_feedback, get_upload_history
)
METRICS_DB_AVAILABLE = True
except ImportError:
METRICS_DB_AVAILABLE = False
router = APIRouter(prefix="/api/v1/admin", tags=["Admin"])
# Upload directory configuration
RAG_UPLOAD_BASE = Path(os.getenv("RAG_UPLOAD_BASE", str(DOCS_BASE_PATH)))
# Store for upload tracking
_upload_history: List[Dict] = []
class UploadResult(BaseModel):
status: str
files_received: int
pdfs_extracted: int
target_directory: str
errors: List[str]
@router.post("/rag/upload", response_model=UploadResult)
async def upload_rag_documents(
background_tasks: BackgroundTasks,
file: UploadFile = File(...),
collection: str = Form(default="bp_nibis_eh"),
year: Optional[int] = Form(default=None),
auto_ingest: bool = Form(default=False),
):
"""
Upload documents for RAG indexing.
Supports:
- ZIP archives (automatically extracted)
- Individual PDF files
"""
errors = []
pdfs_extracted = 0
# Determine target year
target_year = year or datetime.now().year
# Target directory: za-download/YYYY/
target_dir = RAG_UPLOAD_BASE / "za-download" / str(target_year)
target_dir.mkdir(parents=True, exist_ok=True)
try:
filename = file.filename or "upload"
if filename.lower().endswith(".zip"):
# Handle ZIP file
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp:
content = await file.read()
tmp.write(content)
tmp_path = tmp.name
try:
with zipfile.ZipFile(tmp_path, 'r') as zf:
for member in zf.namelist():
if member.lower().endswith(".pdf") and not member.startswith("__MACOSX"):
pdf_name = Path(member).name
if pdf_name:
target_path = target_dir / pdf_name
with zf.open(member) as src:
with open(target_path, 'wb') as dst:
dst.write(src.read())
pdfs_extracted += 1
finally:
os.unlink(tmp_path)
elif filename.lower().endswith(".pdf"):
target_path = target_dir / filename
content = await file.read()
with open(target_path, 'wb') as f:
f.write(content)
pdfs_extracted = 1
else:
raise HTTPException(
status_code=400,
detail=f"Unsupported file type: {filename}. Only .zip and .pdf are allowed."
)
# Track upload in memory
upload_record = {
"timestamp": datetime.now().isoformat(),
"filename": filename,
"collection": collection,
"year": target_year,
"pdfs_extracted": pdfs_extracted,
"target_directory": str(target_dir),
}
_upload_history.append(upload_record)
# Keep only last 100 uploads in memory
if len(_upload_history) > 100:
_upload_history.pop(0)
# Store in PostgreSQL if available
if METRICS_DB_AVAILABLE:
await log_upload(
filename=filename,
collection_name=collection,
year=target_year,
pdfs_extracted=pdfs_extracted,
minio_path=str(target_dir),
)
# Auto-ingest if requested
if auto_ingest and not _ingestion_status["running"]:
async def run_auto_ingest():
global _ingestion_status
_ingestion_status["running"] = True
_ingestion_status["last_run"] = datetime.now().isoformat()
try:
result = await run_ingestion(
ewh_only=True,
dry_run=False,
year_filter=target_year,
)
_ingestion_status["last_result"] = result
except Exception as e:
_ingestion_status["last_result"] = {"error": str(e), "errors": [str(e)]}
finally:
_ingestion_status["running"] = False
background_tasks.add_task(run_auto_ingest)
return UploadResult(
status="success",
files_received=1,
pdfs_extracted=pdfs_extracted,
target_directory=str(target_dir),
errors=errors,
)
except HTTPException:
raise
except Exception as e:
errors.append(str(e))
raise HTTPException(status_code=500, detail=str(e))
@router.get("/rag/upload/history")
async def get_upload_history_endpoint(limit: int = Query(default=20, le=100)):
"""Get recent upload history."""
return {
"uploads": _upload_history[-limit:][::-1],
"total": len(_upload_history),
}
@router.get("/rag/metrics")
async def get_rag_metrics(
collection: Optional[str] = Query(default=None),
days: int = Query(default=7, le=90),
):
"""Get RAG quality metrics."""
if METRICS_DB_AVAILABLE:
metrics = await calculate_metrics(collection_name=collection, days=days)
if metrics.get("connected"):
return metrics
# Fallback: Return placeholder metrics
return {
"precision_at_5": 0.78,
"recall_at_10": 0.85,
"mrr": 0.72,
"avg_latency_ms": 52,
"total_ratings": len(_upload_history),
"error_rate": 0.3,
"score_distribution": {
"0.9+": 23,
"0.7-0.9": 41,
"0.5-0.7": 28,
"<0.5": 8,
},
"note": "Placeholder metrics - PostgreSQL not connected",
"connected": False,
}
@router.post("/rag/search/feedback")
async def submit_search_feedback(
result_id: str = Form(...),
rating: int = Form(..., ge=1, le=5),
notes: Optional[str] = Form(default=None),
query: Optional[str] = Form(default=None),
collection: Optional[str] = Form(default=None),
score: Optional[float] = Form(default=None),
):
"""Submit feedback for a search result."""
feedback_record = {
"timestamp": datetime.now().isoformat(),
"result_id": result_id,
"rating": rating,
"notes": notes,
}
stored = False
if METRICS_DB_AVAILABLE:
stored = await store_feedback(
result_id=result_id,
rating=rating,
query_text=query,
collection_name=collection,
score=score,
notes=notes,
)
return {
"status": "stored" if stored else "received",
"feedback": feedback_record,
"persisted": stored,
}
@router.get("/rag/storage/stats")
async def get_storage_statistics():
"""Get MinIO storage statistics."""
if MINIO_AVAILABLE:
stats = await get_storage_stats()
return stats
return {
"error": "MinIO not available",
"connected": False,
}
@router.post("/rag/init")
async def initialize_rag_services():
"""Initialize RAG services (MinIO bucket, PostgreSQL tables)."""
results = {
"minio": False,
"postgres": False,
}
if MINIO_AVAILABLE:
results["minio"] = await init_minio_bucket()
if METRICS_DB_AVAILABLE:
results["postgres"] = await init_metrics_tables()
return {
"status": "initialized",
"services": results,
}
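For completeness, a hedged client sketch (not part of this commit) of the multipart upload with auto-ingest enabled, again assuming a local deployment and httpx:

import httpx

BASE = "http://localhost:8000/api/v1/admin"  # hypothetical host and port

with open("za_2024.zip", "rb") as fh:  # hypothetical archive of Erwartungshorizont PDFs
    resp = httpx.post(
        f"{BASE}/rag/upload",
        files={"file": ("za_2024.zip", fh, "application/zip")},
        data={"collection": "bp_nibis_eh", "year": "2024", "auto_ingest": "true"},
        timeout=120.0,
    )
resp.raise_for_status()
result = resp.json()
print(result["pdfs_extracted"], "PDFs extracted to", result["target_directory"])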
+389
View File
@@ -0,0 +1,389 @@
"""
Admin API - Legal Templates
Endpoints for legal template ingestion, search, source management,
license info, and collection management.
Extracted from admin_api.py for file-size compliance.
"""
from fastapi import APIRouter, HTTPException, BackgroundTasks, Query
from pydantic import BaseModel
from typing import Optional, List, Dict
from datetime import datetime
from eh_pipeline import generate_single_embedding
# Import legal templates modules
try:
from legal_templates_ingestion import (
LegalTemplatesIngestion,
LEGAL_TEMPLATES_COLLECTION,
)
from template_sources import (
TEMPLATE_SOURCES,
TEMPLATE_TYPES,
JURISDICTIONS,
LicenseType,
get_enabled_sources,
get_sources_by_priority,
)
from qdrant_service import (
search_legal_templates,
get_legal_templates_stats,
init_legal_templates_collection,
)
LEGAL_TEMPLATES_AVAILABLE = True
except ImportError as e:
print(f"Legal templates module not available: {e}")
LEGAL_TEMPLATES_AVAILABLE = False
router = APIRouter(prefix="/api/v1/admin", tags=["Admin"])
# Store for templates ingestion status
_templates_ingestion_status: Dict = {
"running": False,
"last_run": None,
"current_source": None,
"results": {},
}
class TemplatesSearchRequest(BaseModel):
query: str
template_type: Optional[str] = None
license_types: Optional[List[str]] = None
language: Optional[str] = None
jurisdiction: Optional[str] = None
attribution_required: Optional[bool] = None
limit: int = 10
class TemplatesSearchResult(BaseModel):
id: str
score: float
text: str
document_title: Optional[str]
template_type: Optional[str]
clause_category: Optional[str]
language: Optional[str]
jurisdiction: Optional[str]
license_id: Optional[str]
license_name: Optional[str]
attribution_required: Optional[bool]
attribution_text: Optional[str]
source_name: Optional[str]
source_url: Optional[str]
placeholders: Optional[List[str]]
is_complete_document: Optional[bool]
requires_customization: Optional[bool]
class SourceIngestRequest(BaseModel):
source_name: str
@router.get("/templates/status")
async def get_templates_status():
"""Get status of legal templates collection and ingestion."""
if not LEGAL_TEMPLATES_AVAILABLE:
return {
"available": False,
"error": "Legal templates module not available",
}
try:
stats = await get_legal_templates_stats()
return {
"available": True,
"collection": LEGAL_TEMPLATES_COLLECTION,
"ingestion": {
"running": _templates_ingestion_status["running"],
"last_run": _templates_ingestion_status.get("last_run"),
"current_source": _templates_ingestion_status.get("current_source"),
"results": _templates_ingestion_status.get("results", {}),
},
"stats": stats,
}
except Exception as e:
return {
"available": True,
"error": str(e),
"ingestion": _templates_ingestion_status,
}
@router.get("/templates/sources")
async def get_templates_sources():
"""Get list of all template sources with their configuration."""
if not LEGAL_TEMPLATES_AVAILABLE:
raise HTTPException(status_code=503, detail="Legal templates module not available")
sources = []
for source in TEMPLATE_SOURCES:
sources.append({
"name": source.name,
"description": source.description,
"license_type": source.license_type.value,
"license_name": source.license_info.name,
"template_types": source.template_types,
"languages": source.languages,
"jurisdiction": source.jurisdiction,
"repo_url": source.repo_url,
"web_url": source.web_url,
"priority": source.priority,
"enabled": source.enabled,
"attribution_required": source.license_info.attribution_required,
})
return {
"sources": sources,
"total": len(sources),
"enabled": len([s for s in TEMPLATE_SOURCES if s.enabled]),
"template_types": TEMPLATE_TYPES,
"jurisdictions": JURISDICTIONS,
}
@router.get("/templates/licenses")
async def get_templates_licenses():
"""Get license statistics for indexed templates."""
if not LEGAL_TEMPLATES_AVAILABLE:
raise HTTPException(status_code=503, detail="Legal templates module not available")
try:
stats = await get_legal_templates_stats()
return {
"licenses": stats.get("licenses", {}),
"total_chunks": stats.get("points_count", 0),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/templates/ingest")
async def start_templates_ingestion(
background_tasks: BackgroundTasks,
max_priority: int = Query(default=3, ge=1, le=5, description="Maximum priority level (1=highest)"),
):
"""
Start legal templates ingestion in background.
Ingests all enabled sources up to the specified priority level.
"""
if not LEGAL_TEMPLATES_AVAILABLE:
raise HTTPException(status_code=503, detail="Legal templates module not available")
if _templates_ingestion_status["running"]:
raise HTTPException(
status_code=409,
detail="Templates ingestion already running. Check /templates/status for progress."
)
async def run_templates_ingestion():
global _templates_ingestion_status
_templates_ingestion_status["running"] = True
_templates_ingestion_status["last_run"] = datetime.now().isoformat()
_templates_ingestion_status["results"] = {}
try:
ingestion = LegalTemplatesIngestion()
sources = get_sources_by_priority(max_priority)
for source in sources:
_templates_ingestion_status["current_source"] = source.name
try:
status = await ingestion.ingest_source(source)
_templates_ingestion_status["results"][source.name] = {
"status": status.status,
"documents_found": status.documents_found,
"chunks_indexed": status.chunks_indexed,
"errors": status.errors[:5] if status.errors else [],
}
except Exception as e:
_templates_ingestion_status["results"][source.name] = {
"status": "failed",
"error": str(e),
}
await ingestion.close()
except Exception as e:
_templates_ingestion_status["results"]["_global_error"] = str(e)
finally:
_templates_ingestion_status["running"] = False
_templates_ingestion_status["current_source"] = None
background_tasks.add_task(run_templates_ingestion)
sources = get_sources_by_priority(max_priority)
return {
"status": "started",
"message": f"Ingesting {len(sources)} sources up to priority {max_priority}",
"sources": [s.name for s in sources],
}
@router.post("/templates/ingest-source")
async def ingest_single_source(
request: SourceIngestRequest,
background_tasks: BackgroundTasks,
):
"""Ingest a single template source by name."""
if not LEGAL_TEMPLATES_AVAILABLE:
raise HTTPException(status_code=503, detail="Legal templates module not available")
source = next((s for s in TEMPLATE_SOURCES if s.name == request.source_name), None)
if not source:
raise HTTPException(
status_code=404,
detail=f"Source not found: {request.source_name}. Use /templates/sources to list available sources."
)
if not source.enabled:
raise HTTPException(
status_code=400,
detail=f"Source is disabled: {request.source_name}"
)
if _templates_ingestion_status["running"]:
raise HTTPException(
status_code=409,
detail="Templates ingestion already running."
)
async def run_single_ingestion():
global _templates_ingestion_status
_templates_ingestion_status["running"] = True
_templates_ingestion_status["current_source"] = source.name
_templates_ingestion_status["last_run"] = datetime.now().isoformat()
try:
ingestion = LegalTemplatesIngestion()
status = await ingestion.ingest_source(source)
_templates_ingestion_status["results"][source.name] = {
"status": status.status,
"documents_found": status.documents_found,
"chunks_indexed": status.chunks_indexed,
"errors": status.errors[:5] if status.errors else [],
}
await ingestion.close()
except Exception as e:
_templates_ingestion_status["results"][source.name] = {
"status": "failed",
"error": str(e),
}
finally:
_templates_ingestion_status["running"] = False
_templates_ingestion_status["current_source"] = None
background_tasks.add_task(run_single_ingestion)
return {
"status": "started",
"source": source.name,
"license": source.license_type.value,
"template_types": source.template_types,
}
@router.post("/templates/search", response_model=List[TemplatesSearchResult])
async def search_templates(request: TemplatesSearchRequest):
"""Semantic search in legal templates collection."""
if not LEGAL_TEMPLATES_AVAILABLE:
raise HTTPException(status_code=503, detail="Legal templates module not available")
try:
query_embedding = await generate_single_embedding(request.query)
if not query_embedding:
raise HTTPException(status_code=500, detail="Failed to generate embedding")
results = await search_legal_templates(
query_embedding=query_embedding,
template_type=request.template_type,
license_types=request.license_types,
language=request.language,
jurisdiction=request.jurisdiction,
attribution_required=request.attribution_required,
limit=request.limit,
)
return [
TemplatesSearchResult(
id=r["id"],
score=r["score"],
text=r.get("text", "")[:1000],
document_title=r.get("document_title"),
template_type=r.get("template_type"),
clause_category=r.get("clause_category"),
language=r.get("language"),
jurisdiction=r.get("jurisdiction"),
license_id=r.get("license_id"),
license_name=r.get("license_name"),
attribution_required=r.get("attribution_required"),
attribution_text=r.get("attribution_text"),
source_name=r.get("source_name"),
source_url=r.get("source_url"),
placeholders=r.get("placeholders"),
is_complete_document=r.get("is_complete_document"),
requires_customization=r.get("requires_customization"),
)
for r in results
]
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/templates/reset")
async def reset_templates_collection():
"""Delete and recreate the legal templates collection."""
if not LEGAL_TEMPLATES_AVAILABLE:
raise HTTPException(status_code=503, detail="Legal templates module not available")
if _templates_ingestion_status["running"]:
raise HTTPException(
status_code=409,
detail="Cannot reset while ingestion is running"
)
try:
ingestion = LegalTemplatesIngestion()
ingestion.reset_collection()
await ingestion.close()
_templates_ingestion_status["results"] = {}
return {
"status": "reset",
"collection": LEGAL_TEMPLATES_COLLECTION,
"message": "Collection deleted and recreated. Run ingestion to populate.",
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/templates/source/{source_name}")
async def delete_templates_source(source_name: str):
"""Delete all templates from a specific source."""
if not LEGAL_TEMPLATES_AVAILABLE:
raise HTTPException(status_code=503, detail="Legal templates module not available")
try:
from qdrant_service import delete_legal_templates_by_source
count = await delete_legal_templates_by_source(source_name)
if source_name in _templates_ingestion_status.get("results", {}):
del _templates_ingestion_status["results"][source_name]
return {
"status": "deleted",
"source": source_name,
"chunks_deleted": count,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
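Template search follows the same pattern as the NiBiS search endpoint: the query is embedded server-side and filtered against payload fields. A minimal client sketch (not part of this commit), with hypothetical host and filter values:

import httpx

BASE = "http://localhost:8000/api/v1/admin"  # hypothetical host and port

payload = {
    "query": "Auftragsverarbeitungsvertrag Muster",  # example query
    "language": "de",
    "jurisdiction": "DE",  # hypothetical jurisdiction value
    "limit": 5,
}
results = httpx.post(f"{BASE}/templates/search", json=payload).json()
for r in results:
    print(round(r["score"], 3), r["document_title"], r["license_name"])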
+4 -33
View File
@@ -1,33 +1,4 @@
"""
Admin API for NiBiS Data Management (barrel re-export)
This module was split into:
- admin_nibis.py (NiBiS ingestion, search, stats)
- admin_rag.py (RAG upload, metrics, storage)
- admin_templates.py (Legal templates ingestion, search)
The `router` object is assembled here by including all sub-routers.
Importers that did `from admin_api import router` continue to work.
"""
from fastapi import APIRouter
from admin_nibis import router as _nibis_router
from admin_rag import router as _rag_router
from admin_templates import router as _templates_router
# Re-export internal state for test importers
from admin_nibis import ( # noqa: F401
_ingestion_status,
NiBiSSearchRequest,
search_nibis,
)
from admin_rag import _upload_history # noqa: F401
from admin_templates import _templates_ingestion_status # noqa: F401
# Assemble the combined router.
# All sub-routers use prefix="/api/v1/admin", so include without extra prefix.
router = APIRouter()
router.include_router(_nibis_router)
router.include_router(_rag_router)
router.include_router(_templates_router)
# Backward-compat shim -- module moved to admin/api.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("admin.api")
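The shim relies on a module replacing its own entry in sys.modules during import, so subsequent attribute access goes to the package module. A small sanity check (not part of this commit; assumes backend/ is on sys.path and pytest as the runner):

import sys

def test_admin_api_shim_aliases_package_module():
    import admin_api        # executing the shim swaps the sys.modules entry
    from admin import api   # canonical location after the restructure

    # The import statement binds whatever ends up in sys.modules,
    # so both names refer to the very same module object.
    assert admin_api is api
    assert sys.modules["admin_api"] is sys.modules["admin.api"]
    assert hasattr(admin_api, "router")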
+4 -316
View File
@@ -1,316 +1,4 @@
"""
Admin API - NiBiS Ingestion & Search
Endpoints for NiBiS data discovery, ingestion, search, and statistics.
Extracted from admin_api.py for file-size compliance.
"""
from fastapi import APIRouter, HTTPException, BackgroundTasks, Query
from pydantic import BaseModel
from typing import Optional, List, Dict
from datetime import datetime
from nibis_ingestion import (
run_ingestion,
discover_documents,
extract_zip_files,
DOCS_BASE_PATH,
)
from qdrant_service import QdrantService, search_nibis_eh, get_qdrant_client
from eh_pipeline import generate_single_embedding
router = APIRouter(prefix="/api/v1/admin", tags=["Admin"])
# Store for background task status
_ingestion_status: Dict = {
"running": False,
"last_run": None,
"last_result": None,
}
# =============================================================================
# Models
# =============================================================================
class IngestionRequest(BaseModel):
ewh_only: bool = True
year_filter: Optional[int] = None
subject_filter: Optional[str] = None
class IngestionStatus(BaseModel):
running: bool
last_run: Optional[str]
documents_indexed: Optional[int]
chunks_created: Optional[int]
errors: Optional[List[str]]
class NiBiSSearchRequest(BaseModel):
query: str
year: Optional[int] = None
subject: Optional[str] = None
niveau: Optional[str] = None
limit: int = 5
class NiBiSSearchResult(BaseModel):
id: str
score: float
text: str
year: Optional[int]
subject: Optional[str]
niveau: Optional[str]
task_number: Optional[int]
class DataSourceStats(BaseModel):
source_dir: str
year: int
document_count: int
subjects: List[str]
# =============================================================================
# Endpoints
# =============================================================================
@router.get("/nibis/status", response_model=IngestionStatus)
async def get_ingestion_status():
"""Get status of NiBiS ingestion pipeline."""
last_result = _ingestion_status.get("last_result") or {}
return IngestionStatus(
running=_ingestion_status["running"],
last_run=_ingestion_status.get("last_run"),
documents_indexed=last_result.get("documents_indexed"),
chunks_created=last_result.get("chunks_created"),
errors=(last_result.get("errors") or [])[:10],
)
@router.post("/nibis/extract-zips")
async def extract_zip_files_endpoint():
"""Extract all ZIP files in za-download directories."""
try:
extracted = extract_zip_files(DOCS_BASE_PATH)
return {
"status": "success",
"extracted_count": len(extracted),
"directories": [str(d) for d in extracted],
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/nibis/discover")
async def discover_nibis_documents(
ewh_only: bool = Query(True, description="Only return Erwartungshorizonte"),
year: Optional[int] = Query(None, description="Filter by year"),
subject: Optional[str] = Query(None, description="Filter by subject"),
):
"""
Discover available NiBiS documents without indexing.
Useful for previewing what will be indexed.
"""
try:
documents = discover_documents(DOCS_BASE_PATH, ewh_only=ewh_only)
# Apply filters
if year:
documents = [d for d in documents if d.year == year]
if subject:
documents = [d for d in documents if subject.lower() in d.subject.lower()]
# Group by year and subject
by_year: Dict[int, int] = {}
by_subject: Dict[str, int] = {}
for doc in documents:
by_year[doc.year] = by_year.get(doc.year, 0) + 1
by_subject[doc.subject] = by_subject.get(doc.subject, 0) + 1
return {
"total_documents": len(documents),
"by_year": dict(sorted(by_year.items())),
"by_subject": dict(sorted(by_subject.items(), key=lambda x: -x[1])),
"sample_documents": [
{
"id": d.id,
"filename": d.raw_filename,
"year": d.year,
"subject": d.subject,
"niveau": d.niveau,
"doc_type": d.doc_type,
}
for d in documents[:20]
],
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/nibis/ingest")
async def start_ingestion(
request: IngestionRequest,
background_tasks: BackgroundTasks,
):
"""
Start NiBiS data ingestion in background.
"""
if _ingestion_status["running"]:
raise HTTPException(
status_code=409,
detail="Ingestion already running. Check /nibis/status for progress."
)
async def run_ingestion_task():
global _ingestion_status
_ingestion_status["running"] = True
_ingestion_status["last_run"] = datetime.now().isoformat()
try:
result = await run_ingestion(
ewh_only=request.ewh_only,
dry_run=False,
year_filter=request.year_filter,
subject_filter=request.subject_filter,
)
_ingestion_status["last_result"] = result
except Exception as e:
_ingestion_status["last_result"] = {"error": str(e), "errors": [str(e)]}
finally:
_ingestion_status["running"] = False
background_tasks.add_task(run_ingestion_task)
return {
"status": "started",
"message": "Ingestion started in background. Check /nibis/status for progress.",
"filters": {
"ewh_only": request.ewh_only,
"year": request.year_filter,
"subject": request.subject_filter,
},
}
@router.post("/nibis/search", response_model=List[NiBiSSearchResult])
async def search_nibis(request: NiBiSSearchRequest):
"""
Semantic search in NiBiS Erwartungshorizonte.
"""
try:
query_embedding = await generate_single_embedding(request.query)
if not query_embedding:
raise HTTPException(status_code=500, detail="Failed to generate embedding")
results = await search_nibis_eh(
query_embedding=query_embedding,
year=request.year,
subject=request.subject,
niveau=request.niveau,
limit=request.limit,
)
return [
NiBiSSearchResult(
id=r["id"],
score=r["score"],
text=r.get("text", "")[:500],
year=r.get("year"),
subject=r.get("subject"),
niveau=r.get("niveau"),
task_number=r.get("task_number"),
)
for r in results
]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/nibis/collections")
async def get_collections_info():
"""Get information about all Qdrant collections."""
try:
client = get_qdrant_client()
collections = client.get_collections().collections
result = []
for c in collections:
try:
info = client.get_collection(c.name)
result.append({
"name": c.name,
"vectors_count": info.vectors_count,
"points_count": info.points_count,
"status": info.status.value,
})
except Exception as e:
result.append({
"name": c.name,
"error": str(e),
})
return {"collections": result}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/nibis/stats")
async def get_nibis_stats():
"""Get detailed statistics about indexed NiBiS data."""
try:
qdrant = QdrantService()
stats = await qdrant.get_stats("bp_nibis_eh")
if "error" in stats:
return {
"indexed": False,
"message": "NiBiS collection not yet created. Run ingestion first.",
}
client = get_qdrant_client()
scroll_result = client.scroll(
collection_name="bp_nibis_eh",
limit=1000,
with_payload=True,
with_vectors=False,
)
years = set()
subjects = set()
niveaus = set()
for point in scroll_result[0]:
if point.payload:
if "year" in point.payload:
years.add(point.payload["year"])
if "subject" in point.payload:
subjects.add(point.payload["subject"])
if "niveau" in point.payload:
niveaus.add(point.payload["niveau"])
return {
"indexed": True,
"total_chunks": stats.get("points_count", 0),
"years": sorted(list(years)),
"subjects": sorted(list(subjects)),
"niveaus": sorted(list(niveaus)),
}
except Exception as e:
return {
"indexed": False,
"error": str(e),
}
@router.delete("/nibis/collection")
async def delete_nibis_collection():
"""Delete the entire NiBiS collection. WARNING: removes all indexed data!"""
try:
client = get_qdrant_client()
client.delete_collection("bp_nibis_eh")
return {"status": "deleted", "collection": "bp_nibis_eh"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Backward-compat shim -- module moved to admin/nibis.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("admin.nibis")
+4 -281
View File
@@ -1,281 +1,4 @@
"""
Admin API - RAG Upload & Metrics
Endpoints for uploading documents, tracking uploads, RAG metrics,
search feedback, storage stats, and service initialization.
Extracted from admin_api.py for file-size compliance.
"""
from fastapi import APIRouter, HTTPException, BackgroundTasks, Query, UploadFile, File, Form
from pydantic import BaseModel
from typing import Optional, List, Dict
from datetime import datetime
from pathlib import Path
import zipfile
import tempfile
import os
from nibis_ingestion import run_ingestion, DOCS_BASE_PATH
# Import ingestion status from nibis module for auto-ingest
from admin_nibis import _ingestion_status
# Optional: MinIO and PostgreSQL integrations
try:
from minio_storage import upload_rag_document, get_storage_stats, init_minio_bucket
MINIO_AVAILABLE = True
except ImportError:
MINIO_AVAILABLE = False
try:
from metrics_db import (
init_metrics_tables, store_feedback, log_search, log_upload,
calculate_metrics, get_recent_feedback, get_upload_history
)
METRICS_DB_AVAILABLE = True
except ImportError:
METRICS_DB_AVAILABLE = False
router = APIRouter(prefix="/api/v1/admin", tags=["Admin"])
# Upload directory configuration
RAG_UPLOAD_BASE = Path(os.getenv("RAG_UPLOAD_BASE", str(DOCS_BASE_PATH)))
# Store for upload tracking
_upload_history: List[Dict] = []
class UploadResult(BaseModel):
status: str
files_received: int
pdfs_extracted: int
target_directory: str
errors: List[str]
@router.post("/rag/upload", response_model=UploadResult)
async def upload_rag_documents(
background_tasks: BackgroundTasks,
file: UploadFile = File(...),
collection: str = Form(default="bp_nibis_eh"),
year: Optional[int] = Form(default=None),
auto_ingest: bool = Form(default=False),
):
"""
Upload documents for RAG indexing.
Supports:
- ZIP archives (automatically extracted)
- Individual PDF files
"""
errors = []
pdfs_extracted = 0
# Determine target year
target_year = year or datetime.now().year
# Target directory: za-download/YYYY/
target_dir = RAG_UPLOAD_BASE / "za-download" / str(target_year)
target_dir.mkdir(parents=True, exist_ok=True)
try:
filename = file.filename or "upload"
if filename.lower().endswith(".zip"):
# Handle ZIP file
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp:
content = await file.read()
tmp.write(content)
tmp_path = tmp.name
try:
with zipfile.ZipFile(tmp_path, 'r') as zf:
for member in zf.namelist():
if member.lower().endswith(".pdf") and not member.startswith("__MACOSX"):
pdf_name = Path(member).name
if pdf_name:
target_path = target_dir / pdf_name
with zf.open(member) as src:
with open(target_path, 'wb') as dst:
dst.write(src.read())
pdfs_extracted += 1
finally:
os.unlink(tmp_path)
elif filename.lower().endswith(".pdf"):
target_path = target_dir / filename
content = await file.read()
with open(target_path, 'wb') as f:
f.write(content)
pdfs_extracted = 1
else:
raise HTTPException(
status_code=400,
detail=f"Unsupported file type: {filename}. Only .zip and .pdf are allowed."
)
# Track upload in memory
upload_record = {
"timestamp": datetime.now().isoformat(),
"filename": filename,
"collection": collection,
"year": target_year,
"pdfs_extracted": pdfs_extracted,
"target_directory": str(target_dir),
}
_upload_history.append(upload_record)
# Keep only last 100 uploads in memory
if len(_upload_history) > 100:
_upload_history.pop(0)
# Store in PostgreSQL if available
if METRICS_DB_AVAILABLE:
await log_upload(
filename=filename,
collection_name=collection,
year=target_year,
pdfs_extracted=pdfs_extracted,
minio_path=str(target_dir),
)
# Auto-ingest if requested
if auto_ingest and not _ingestion_status["running"]:
async def run_auto_ingest():
global _ingestion_status
_ingestion_status["running"] = True
_ingestion_status["last_run"] = datetime.now().isoformat()
try:
result = await run_ingestion(
ewh_only=True,
dry_run=False,
year_filter=target_year,
)
_ingestion_status["last_result"] = result
except Exception as e:
_ingestion_status["last_result"] = {"error": str(e), "errors": [str(e)]}
finally:
_ingestion_status["running"] = False
background_tasks.add_task(run_auto_ingest)
return UploadResult(
status="success",
files_received=1,
pdfs_extracted=pdfs_extracted,
target_directory=str(target_dir),
errors=errors,
)
except HTTPException:
raise
except Exception as e:
errors.append(str(e))
raise HTTPException(status_code=500, detail=str(e))
@router.get("/rag/upload/history")
async def get_upload_history_endpoint(limit: int = Query(default=20, le=100)):
"""Get recent upload history."""
return {
"uploads": _upload_history[-limit:][::-1],
"total": len(_upload_history),
}
@router.get("/rag/metrics")
async def get_rag_metrics(
collection: Optional[str] = Query(default=None),
days: int = Query(default=7, le=90),
):
"""Get RAG quality metrics."""
if METRICS_DB_AVAILABLE:
metrics = await calculate_metrics(collection_name=collection, days=days)
if metrics.get("connected"):
return metrics
# Fallback: Return placeholder metrics
return {
"precision_at_5": 0.78,
"recall_at_10": 0.85,
"mrr": 0.72,
"avg_latency_ms": 52,
"total_ratings": len(_upload_history),
"error_rate": 0.3,
"score_distribution": {
"0.9+": 23,
"0.7-0.9": 41,
"0.5-0.7": 28,
"<0.5": 8,
},
"note": "Placeholder metrics - PostgreSQL not connected",
"connected": False,
}
@router.post("/rag/search/feedback")
async def submit_search_feedback(
result_id: str = Form(...),
rating: int = Form(..., ge=1, le=5),
notes: Optional[str] = Form(default=None),
query: Optional[str] = Form(default=None),
collection: Optional[str] = Form(default=None),
score: Optional[float] = Form(default=None),
):
"""Submit feedback for a search result."""
feedback_record = {
"timestamp": datetime.now().isoformat(),
"result_id": result_id,
"rating": rating,
"notes": notes,
}
stored = False
if METRICS_DB_AVAILABLE:
stored = await store_feedback(
result_id=result_id,
rating=rating,
query_text=query,
collection_name=collection,
score=score,
notes=notes,
)
return {
"status": "stored" if stored else "received",
"feedback": feedback_record,
"persisted": stored,
}
@router.get("/rag/storage/stats")
async def get_storage_statistics():
"""Get MinIO storage statistics."""
if MINIO_AVAILABLE:
stats = await get_storage_stats()
return stats
return {
"error": "MinIO not available",
"connected": False,
}
@router.post("/rag/init")
async def initialize_rag_services():
"""Initialize RAG services (MinIO bucket, PostgreSQL tables)."""
results = {
"minio": False,
"postgres": False,
}
if MINIO_AVAILABLE:
results["minio"] = await init_minio_bucket()
if METRICS_DB_AVAILABLE:
results["postgres"] = await init_metrics_tables()
return {
"status": "initialized",
"services": results,
}
# Backward-compat shim -- module moved to admin/rag.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("admin.rag")
+4 -389
View File
@@ -1,389 +1,4 @@
"""
Admin API - Legal Templates
Endpoints for legal template ingestion, search, source management,
license info, and collection management.
Extracted from admin_api.py for file-size compliance.
"""
from fastapi import APIRouter, HTTPException, BackgroundTasks, Query
from pydantic import BaseModel
from typing import Optional, List, Dict
from datetime import datetime
from eh_pipeline import generate_single_embedding
# Import legal templates modules
try:
from legal_templates_ingestion import (
LegalTemplatesIngestion,
LEGAL_TEMPLATES_COLLECTION,
)
from template_sources import (
TEMPLATE_SOURCES,
TEMPLATE_TYPES,
JURISDICTIONS,
LicenseType,
get_enabled_sources,
get_sources_by_priority,
)
from qdrant_service import (
search_legal_templates,
get_legal_templates_stats,
init_legal_templates_collection,
)
LEGAL_TEMPLATES_AVAILABLE = True
except ImportError as e:
print(f"Legal templates module not available: {e}")
LEGAL_TEMPLATES_AVAILABLE = False
router = APIRouter(prefix="/api/v1/admin", tags=["Admin"])
# Store for templates ingestion status
_templates_ingestion_status: Dict = {
"running": False,
"last_run": None,
"current_source": None,
"results": {},
}
class TemplatesSearchRequest(BaseModel):
query: str
template_type: Optional[str] = None
license_types: Optional[List[str]] = None
language: Optional[str] = None
jurisdiction: Optional[str] = None
attribution_required: Optional[bool] = None
limit: int = 10
class TemplatesSearchResult(BaseModel):
id: str
score: float
text: str
document_title: Optional[str]
template_type: Optional[str]
clause_category: Optional[str]
language: Optional[str]
jurisdiction: Optional[str]
license_id: Optional[str]
license_name: Optional[str]
attribution_required: Optional[bool]
attribution_text: Optional[str]
source_name: Optional[str]
source_url: Optional[str]
placeholders: Optional[List[str]]
is_complete_document: Optional[bool]
requires_customization: Optional[bool]
class SourceIngestRequest(BaseModel):
source_name: str
@router.get("/templates/status")
async def get_templates_status():
"""Get status of legal templates collection and ingestion."""
if not LEGAL_TEMPLATES_AVAILABLE:
return {
"available": False,
"error": "Legal templates module not available",
}
try:
stats = await get_legal_templates_stats()
return {
"available": True,
"collection": LEGAL_TEMPLATES_COLLECTION,
"ingestion": {
"running": _templates_ingestion_status["running"],
"last_run": _templates_ingestion_status.get("last_run"),
"current_source": _templates_ingestion_status.get("current_source"),
"results": _templates_ingestion_status.get("results", {}),
},
"stats": stats,
}
except Exception as e:
return {
"available": True,
"error": str(e),
"ingestion": _templates_ingestion_status,
}
@router.get("/templates/sources")
async def get_templates_sources():
"""Get list of all template sources with their configuration."""
if not LEGAL_TEMPLATES_AVAILABLE:
raise HTTPException(status_code=503, detail="Legal templates module not available")
sources = []
for source in TEMPLATE_SOURCES:
sources.append({
"name": source.name,
"description": source.description,
"license_type": source.license_type.value,
"license_name": source.license_info.name,
"template_types": source.template_types,
"languages": source.languages,
"jurisdiction": source.jurisdiction,
"repo_url": source.repo_url,
"web_url": source.web_url,
"priority": source.priority,
"enabled": source.enabled,
"attribution_required": source.license_info.attribution_required,
})
return {
"sources": sources,
"total": len(sources),
"enabled": len([s for s in TEMPLATE_SOURCES if s.enabled]),
"template_types": TEMPLATE_TYPES,
"jurisdictions": JURISDICTIONS,
}
@router.get("/templates/licenses")
async def get_templates_licenses():
"""Get license statistics for indexed templates."""
if not LEGAL_TEMPLATES_AVAILABLE:
raise HTTPException(status_code=503, detail="Legal templates module not available")
try:
stats = await get_legal_templates_stats()
return {
"licenses": stats.get("licenses", {}),
"total_chunks": stats.get("points_count", 0),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/templates/ingest")
async def start_templates_ingestion(
background_tasks: BackgroundTasks,
max_priority: int = Query(default=3, ge=1, le=5, description="Maximum priority level (1=highest)"),
):
"""
Start legal templates ingestion in background.
Ingests all enabled sources up to the specified priority level.
"""
if not LEGAL_TEMPLATES_AVAILABLE:
raise HTTPException(status_code=503, detail="Legal templates module not available")
if _templates_ingestion_status["running"]:
raise HTTPException(
status_code=409,
detail="Templates ingestion already running. Check /templates/status for progress."
)
async def run_templates_ingestion():
global _templates_ingestion_status
_templates_ingestion_status["running"] = True
_templates_ingestion_status["last_run"] = datetime.now().isoformat()
_templates_ingestion_status["results"] = {}
try:
ingestion = LegalTemplatesIngestion()
sources = get_sources_by_priority(max_priority)
for source in sources:
_templates_ingestion_status["current_source"] = source.name
try:
status = await ingestion.ingest_source(source)
_templates_ingestion_status["results"][source.name] = {
"status": status.status,
"documents_found": status.documents_found,
"chunks_indexed": status.chunks_indexed,
"errors": status.errors[:5] if status.errors else [],
}
except Exception as e:
_templates_ingestion_status["results"][source.name] = {
"status": "failed",
"error": str(e),
}
await ingestion.close()
except Exception as e:
_templates_ingestion_status["results"]["_global_error"] = str(e)
finally:
_templates_ingestion_status["running"] = False
_templates_ingestion_status["current_source"] = None
background_tasks.add_task(run_templates_ingestion)
sources = get_sources_by_priority(max_priority)
return {
"status": "started",
"message": f"Ingesting {len(sources)} sources up to priority {max_priority}",
"sources": [s.name for s in sources],
}
@router.post("/templates/ingest-source")
async def ingest_single_source(
request: SourceIngestRequest,
background_tasks: BackgroundTasks,
):
"""Ingest a single template source by name."""
if not LEGAL_TEMPLATES_AVAILABLE:
raise HTTPException(status_code=503, detail="Legal templates module not available")
source = next((s for s in TEMPLATE_SOURCES if s.name == request.source_name), None)
if not source:
raise HTTPException(
status_code=404,
detail=f"Source not found: {request.source_name}. Use /templates/sources to list available sources."
)
if not source.enabled:
raise HTTPException(
status_code=400,
detail=f"Source is disabled: {request.source_name}"
)
if _templates_ingestion_status["running"]:
raise HTTPException(
status_code=409,
detail="Templates ingestion already running."
)
async def run_single_ingestion():
global _templates_ingestion_status
_templates_ingestion_status["running"] = True
_templates_ingestion_status["current_source"] = source.name
_templates_ingestion_status["last_run"] = datetime.now().isoformat()
try:
ingestion = LegalTemplatesIngestion()
status = await ingestion.ingest_source(source)
_templates_ingestion_status["results"][source.name] = {
"status": status.status,
"documents_found": status.documents_found,
"chunks_indexed": status.chunks_indexed,
"errors": status.errors[:5] if status.errors else [],
}
await ingestion.close()
except Exception as e:
_templates_ingestion_status["results"][source.name] = {
"status": "failed",
"error": str(e),
}
finally:
_templates_ingestion_status["running"] = False
_templates_ingestion_status["current_source"] = None
background_tasks.add_task(run_single_ingestion)
return {
"status": "started",
"source": source.name,
"license": source.license_type.value,
"template_types": source.template_types,
}
@router.post("/templates/search", response_model=List[TemplatesSearchResult])
async def search_templates(request: TemplatesSearchRequest):
"""Semantic search in legal templates collection."""
if not LEGAL_TEMPLATES_AVAILABLE:
raise HTTPException(status_code=503, detail="Legal templates module not available")
try:
query_embedding = await generate_single_embedding(request.query)
if not query_embedding:
raise HTTPException(status_code=500, detail="Failed to generate embedding")
results = await search_legal_templates(
query_embedding=query_embedding,
template_type=request.template_type,
license_types=request.license_types,
language=request.language,
jurisdiction=request.jurisdiction,
attribution_required=request.attribution_required,
limit=request.limit,
)
return [
TemplatesSearchResult(
id=r["id"],
score=r["score"],
text=r.get("text", "")[:1000],
document_title=r.get("document_title"),
template_type=r.get("template_type"),
clause_category=r.get("clause_category"),
language=r.get("language"),
jurisdiction=r.get("jurisdiction"),
license_id=r.get("license_id"),
license_name=r.get("license_name"),
attribution_required=r.get("attribution_required"),
attribution_text=r.get("attribution_text"),
source_name=r.get("source_name"),
source_url=r.get("source_url"),
placeholders=r.get("placeholders"),
is_complete_document=r.get("is_complete_document"),
requires_customization=r.get("requires_customization"),
)
for r in results
]
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/templates/reset")
async def reset_templates_collection():
"""Delete and recreate the legal templates collection."""
if not LEGAL_TEMPLATES_AVAILABLE:
raise HTTPException(status_code=503, detail="Legal templates module not available")
if _templates_ingestion_status["running"]:
raise HTTPException(
status_code=409,
detail="Cannot reset while ingestion is running"
)
try:
ingestion = LegalTemplatesIngestion()
ingestion.reset_collection()
await ingestion.close()
_templates_ingestion_status["results"] = {}
return {
"status": "reset",
"collection": LEGAL_TEMPLATES_COLLECTION,
"message": "Collection deleted and recreated. Run ingestion to populate.",
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/templates/source/{source_name}")
async def delete_templates_source(source_name: str):
"""Delete all templates from a specific source."""
if not LEGAL_TEMPLATES_AVAILABLE:
raise HTTPException(status_code=503, detail="Legal templates module not available")
try:
from qdrant_service import delete_legal_templates_by_source
count = await delete_legal_templates_by_source(source_name)
if source_name in _templates_ingestion_status.get("results", {}):
del _templates_ingestion_status["results"][source_name]
return {
"status": "deleted",
"source": source_name,
"chunks_deleted": count,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# Backward-compat shim -- module moved to admin/templates.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("admin.templates")
@@ -0,0 +1,6 @@
"""
compliance package — compliance pipeline, RBAC/ABAC policy engine.
Backward-compatible re-exports: consumers can still use
``from compliance_models import ...`` etc. via the shim files in backend/.
"""
@@ -0,0 +1,200 @@
"""
Compliance Extraction & Generation.
Functions for extracting checkpoints from legal text chunks,
generating controls, and creating remediation measures.
"""
import re
import hashlib
import logging
from typing import Dict, List, Optional
from .models import Checkpoint, Control, Measure
logger = logging.getLogger(__name__)
def extract_checkpoints_from_chunk(chunk_text: str, payload: Dict) -> List[Checkpoint]:
"""
Extract checkpoints/requirements from a chunk of text.
Uses pattern matching to find requirement-like statements.
"""
checkpoints = []
regulation_code = payload.get("regulation_code", "UNKNOWN")
regulation_name = payload.get("regulation_name", "Unknown")
source_url = payload.get("source_url", "")
chunk_id = hashlib.md5(chunk_text[:100].encode()).hexdigest()[:8]
# Patterns for different requirement types
patterns = [
# BSI-TR patterns
(r'([OT]\.[A-Za-z_]+\d*)[:\s]+(.+?)(?=\n[OT]\.|$)', 'bsi_requirement'),
# Article patterns (GDPR, AI Act, etc.)
(r'(?:Artikel|Art\.?)\s+(\d+)(?:\s+Abs(?:atz)?\.?\s*(\d+))?\s*[-–:]\s*(.+?)(?=\n|$)', 'article'),
# Numbered requirements
(r'\((\d+)\)\s+(.+?)(?=\n\(\d+\)|$)', 'numbered'),
# "Der Verantwortliche muss" patterns
(r'(?:Der Verantwortliche|Die Aufsichtsbehörde|Der Auftragsverarbeiter)\s+(muss|hat|soll)\s+(.+?)(?=\.\s|$)', 'obligation'),
# "Es ist erforderlich" patterns
(r'(?:Es ist erforderlich|Es muss gewährleistet|Es sind geeignete)\s+(.+?)(?=\.\s|$)', 'requirement'),
]
for pattern, pattern_type in patterns:
matches = re.finditer(pattern, chunk_text, re.MULTILINE | re.DOTALL)
for match in matches:
if pattern_type == 'bsi_requirement':
req_id = match.group(1)
description = match.group(2).strip()
title = req_id
elif pattern_type == 'article':
article_num = match.group(1)
paragraph = match.group(2) or ""
title_text = match.group(3).strip()
req_id = f"{regulation_code}-Art{article_num}"
if paragraph:
req_id += f"-{paragraph}"
title = f"Art. {article_num}" + (f" Abs. {paragraph}" if paragraph else "")
description = title_text
elif pattern_type == 'numbered':
num = match.group(1)
description = match.group(2).strip()
req_id = f"{regulation_code}-{num}"
title = f"Anforderung {num}"
else:
# Generic requirement
description = match.group(0).strip()
req_id = f"{regulation_code}-{chunk_id}-{len(checkpoints)}"
title = description[:50] + "..." if len(description) > 50 else description
# Skip very short matches
if len(description) < 20:
continue
checkpoint = Checkpoint(
id=req_id,
regulation_code=regulation_code,
regulation_name=regulation_name,
article=title if 'Art' in title else None,
title=title,
description=description[:500],
original_text=description,
chunk_id=chunk_id,
source_url=source_url
)
checkpoints.append(checkpoint)
return checkpoints
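# Illustrative call (sketch; payload keys mirror the ones read above, values invented):
#
#   cps = extract_checkpoints_from_chunk(
#       "Artikel 32 - Sicherheit der Verarbeitung ist zu gewaehrleisten.",
#       {"regulation_code": "DSGVO", "regulation_name": "DSGVO", "source_url": ""},
#   )
#   # -> one Checkpoint: id "DSGVO-Art32", title "Art. 32",
#   #    matched by the 'article' pattern and long enough to pass the 20-char filter.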
def generate_control_for_checkpoints(
checkpoints: List[Checkpoint],
domain_counts: Dict[str, int],
) -> Optional[Control]:
"""
Generate a control that covers the given checkpoints.
    This is a simplified, pattern-based version; in production this would use the AI assistant.
"""
if not checkpoints:
return None
# Group by regulation
regulation = checkpoints[0].regulation_code
# Determine domain based on content
all_text = " ".join([cp.description for cp in checkpoints]).lower()
domain = "gov" # Default
if any(kw in all_text for kw in ["verschl\u00fcssel", "krypto", "encrypt", "hash"]):
domain = "crypto"
elif any(kw in all_text for kw in ["zugang", "access", "authentif", "login", "benutzer"]):
domain = "iam"
elif any(kw in all_text for kw in ["datenschutz", "personenbezogen", "privacy", "einwilligung"]):
domain = "priv"
elif any(kw in all_text for kw in ["entwicklung", "test", "code", "software"]):
domain = "sdlc"
elif any(kw in all_text for kw in ["\u00fcberwach", "monitor", "log", "audit"]):
domain = "aud"
elif any(kw in all_text for kw in ["ki", "k\u00fcnstlich", "ai", "machine learning", "model"]):
domain = "ai"
elif any(kw in all_text for kw in ["betrieb", "operation", "verf\u00fcgbar", "backup"]):
domain = "ops"
elif any(kw in all_text for kw in ["cyber", "resilience", "sbom", "vulnerab"]):
domain = "cra"
# Generate control ID
domain_count = domain_counts.get(domain, 0) + 1
control_id = f"{domain.upper()}-{domain_count:03d}"
# Create title from first checkpoint
title = checkpoints[0].title
if len(title) > 100:
title = title[:97] + "..."
# Create description
description = f"Control f\u00fcr {regulation}: " + checkpoints[0].description[:200]
# Pass criteria
pass_criteria = f"Alle {len(checkpoints)} zugeh\u00f6rigen Anforderungen sind erf\u00fcllt und dokumentiert."
# Implementation guidance
guidance = f"Implementiere Ma\u00dfnahmen zur Erf\u00fcllung der Anforderungen aus {regulation}. "
guidance += f"Dokumentiere die Umsetzung und f\u00fchre regelm\u00e4\u00dfige Reviews durch."
# Determine if automated
is_automated = any(kw in all_text for kw in ["automat", "tool", "scan", "test"])
control = Control(
id=control_id,
domain=domain,
title=title,
description=description,
checkpoints=[cp.id for cp in checkpoints],
pass_criteria=pass_criteria,
implementation_guidance=guidance,
is_automated=is_automated,
automation_tool="CI/CD Pipeline" if is_automated else None,
priority="high" if "muss" in all_text or "erforderlich" in all_text else "medium"
)
return control
def generate_measure_for_control(control: Control) -> Measure:
"""Generate a remediation measure for a control."""
measure_id = f"M-{control.id}"
# Determine deadline based on priority
deadline_days = {
"critical": 30,
"high": 60,
"medium": 90,
"low": 180
}.get(control.priority, 90)
# Determine responsible team
responsible = {
"priv": "Datenschutzbeauftragter",
"iam": "IT-Security Team",
"sdlc": "Entwicklungsteam",
"crypto": "IT-Security Team",
"ops": "Operations Team",
"aud": "Compliance Team",
"ai": "AI/ML Team",
"cra": "IT-Security Team",
"gov": "Management"
}.get(control.domain, "Compliance Team")
measure = Measure(
id=measure_id,
control_id=control.id,
title=f"Umsetzung: {control.title[:50]}",
description=f"Implementierung und Dokumentation von {control.id}: {control.description[:100]}",
responsible=responsible,
deadline_days=deadline_days,
status="pending"
)
return measure
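# Minimal end-to-end sketch of the three helpers above. Values are invented;
# run with `python -m compliance.extraction` because of the relative imports.
if __name__ == "__main__":
    _demo_cps = extract_checkpoints_from_chunk(
        "Artikel 32 - Sicherheit der Verarbeitung ist zu gewaehrleisten.",
        {"regulation_code": "DSGVO", "regulation_name": "DSGVO", "source_url": ""},
    )
    _demo_control = generate_control_for_checkpoints(_demo_cps, {})
    if _demo_control:
        _demo_measure = generate_measure_for_control(_demo_control)
        # prints e.g. "GOV-001 gov Management" for the sample text above
        print(_demo_control.id, _demo_control.domain, _demo_measure.responsible)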
@@ -0,0 +1,65 @@
#!/usr/bin/env python3
"""
Full Compliance Pipeline for Legal Corpus — Barrel Re-export.
Split into submodules:
- compliance_models.py — Dataclasses (Checkpoint, Control, Measure)
- compliance_extraction.py — Pattern extraction & control/measure generation
- compliance_pipeline.py — Pipeline phases & orchestrator
Run on Mac Mini:
nohup python full_compliance_pipeline.py > /tmp/compliance_pipeline.log 2>&1 &
"""
import asyncio
import logging
import sys
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout),
logging.FileHandler('/tmp/compliance_pipeline.log')
]
)
# Re-export all public symbols
from .models import Checkpoint, Control, Measure
from .extraction import (
extract_checkpoints_from_chunk,
generate_control_for_checkpoints,
generate_measure_for_control,
)
from .pipeline import CompliancePipeline
__all__ = [
"Checkpoint",
"Control",
"Measure",
"extract_checkpoints_from_chunk",
"generate_control_for_checkpoints",
"generate_measure_for_control",
"CompliancePipeline",
]
async def main():
import argparse
parser = argparse.ArgumentParser(description="Run the compliance pipeline")
parser.add_argument("--force-reindex", action="store_true",
help="Force re-ingestion of all documents")
parser.add_argument("--skip-ingestion", action="store_true",
help="Skip ingestion phase, use existing chunks")
args = parser.parse_args()
pipeline = CompliancePipeline()
await pipeline.run_full_pipeline(
force_reindex=args.force_reindex,
skip_ingestion=args.skip_ingestion
)
if __name__ == "__main__":
asyncio.run(main())
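# Illustrative invocations (flags as defined in main() above; the flat
# `full_compliance_pipeline.py` entry point is the one named in the module
# docstring):
#
#   python full_compliance_pipeline.py --skip-ingestion    # reuse existing chunks
#   python full_compliance_pipeline.py --force-reindex     # re-ingest everything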
@@ -0,0 +1,49 @@
"""
Compliance Pipeline Data Models.
Dataclasses for checkpoints, controls, and measures.
"""
from typing import Optional, List
from dataclasses import dataclass
@dataclass
class Checkpoint:
"""A requirement/checkpoint extracted from legal text."""
id: str
regulation_code: str
regulation_name: str
article: Optional[str]
title: str
description: str
original_text: str
chunk_id: str
source_url: str
@dataclass
class Control:
"""A control derived from checkpoints."""
id: str
domain: str
title: str
description: str
checkpoints: List[str] # List of checkpoint IDs
pass_criteria: str
implementation_guidance: str
is_automated: bool
automation_tool: Optional[str]
priority: str
@dataclass
class Measure:
"""A remediation measure for a control."""
id: str
control_id: str
title: str
description: str
responsible: str
deadline_days: int
status: str
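if __name__ == "__main__":
    # Tiny illustrative smoke check (invented values): the dataclasses
    # round-trip through dataclasses.asdict, which is how the pipeline
    # serializes them to JSON.
    from dataclasses import asdict
    demo = Measure(
        id="M-GOV-001", control_id="GOV-001",
        title="Umsetzung: Beispiel", description="Beispielmassnahme",
        responsible="Management", deadline_days=90, status="pending",
    )
    print(asdict(demo)["responsible"])  # -> Management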
@@ -0,0 +1,441 @@
"""
Compliance Pipeline Execution.
Pipeline phases (ingestion, extraction, control generation, measures)
and orchestration logic.
"""
import asyncio
import json
import logging
import os
import sys
import time
from datetime import datetime
from typing import Dict, List, Any
from dataclasses import asdict
from .models import Checkpoint, Control, Measure
from .extraction import (
extract_checkpoints_from_chunk,
generate_control_for_checkpoints,
generate_measure_for_control,
)
logger = logging.getLogger(__name__)
# Import checkpoint manager
try:
from pipeline_checkpoints import CheckpointManager, EXPECTED_VALUES, ValidationStatus
except ImportError:
logger.warning("Checkpoint manager not available, running without checkpoints")
CheckpointManager = None
EXPECTED_VALUES = {}
ValidationStatus = None
# Set environment variables for Docker network
if not os.getenv("QDRANT_URL") and not os.getenv("QDRANT_HOST"):
os.environ["QDRANT_HOST"] = "qdrant"
os.environ.setdefault("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
# Try to import from klausur-service
try:
from legal_corpus_ingestion import LegalCorpusIngestion, REGULATIONS, LEGAL_CORPUS_COLLECTION
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
except ImportError:
logger.error("Could not import required modules. Make sure you're in the klausur-service container.")
sys.exit(1)
class CompliancePipeline:
"""Handles the full compliance pipeline."""
def __init__(self):
# Support both QDRANT_URL and QDRANT_HOST/PORT
qdrant_url = os.getenv("QDRANT_URL", "")
if qdrant_url:
from urllib.parse import urlparse
parsed = urlparse(qdrant_url)
qdrant_host = parsed.hostname or "qdrant"
qdrant_port = parsed.port or 6333
else:
qdrant_host = os.getenv("QDRANT_HOST", "qdrant")
qdrant_port = 6333
self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port)
self.checkpoints: List[Checkpoint] = []
self.controls: List[Control] = []
self.measures: List[Measure] = []
self.stats = {
"chunks_processed": 0,
"checkpoints_extracted": 0,
"controls_created": 0,
"measures_defined": 0,
"by_regulation": {},
"by_domain": {},
}
# Initialize checkpoint manager
self.checkpoint_mgr = CheckpointManager() if CheckpointManager else None
async def run_ingestion_phase(self, force_reindex: bool = False) -> int:
"""Phase 1: Ingest documents (incremental - only missing ones)."""
logger.info("\n" + "=" * 60)
logger.info("PHASE 1: DOCUMENT INGESTION (INCREMENTAL)")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.start_checkpoint("ingestion", "Document Ingestion")
ingestion = LegalCorpusIngestion()
try:
# Check existing chunks per regulation
existing_chunks = {}
try:
for regulation in REGULATIONS:
count_result = self.qdrant.count(
collection_name=LEGAL_CORPUS_COLLECTION,
count_filter=Filter(
must=[FieldCondition(key="regulation_code", match=MatchValue(value=regulation.code))]
)
)
existing_chunks[regulation.code] = count_result.count
logger.info(f" {regulation.code}: {count_result.count} existing chunks")
except Exception as e:
logger.warning(f"Could not check existing chunks: {e}")
# Determine which regulations need ingestion
regulations_to_ingest = []
for regulation in REGULATIONS:
existing = existing_chunks.get(regulation.code, 0)
if force_reindex or existing == 0:
regulations_to_ingest.append(regulation)
logger.info(f" -> Will ingest: {regulation.code} (existing: {existing}, force: {force_reindex})")
else:
logger.info(f" -> Skipping: {regulation.code} (already has {existing} chunks)")
self.stats["by_regulation"][regulation.code] = existing
if not regulations_to_ingest:
logger.info("All regulations already indexed. Skipping ingestion phase.")
total_chunks = sum(existing_chunks.values())
self.stats["chunks_processed"] = total_chunks
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
self.checkpoint_mgr.add_metric("skipped", True)
self.checkpoint_mgr.complete_checkpoint(success=True)
return total_chunks
# Ingest only missing regulations
total_chunks = sum(existing_chunks.values())
for i, regulation in enumerate(regulations_to_ingest, 1):
logger.info(f"[{i}/{len(regulations_to_ingest)}] Ingesting {regulation.code}...")
try:
count = await ingestion.ingest_regulation(regulation)
total_chunks += count
self.stats["by_regulation"][regulation.code] = count
logger.info(f" -> {count} chunks")
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric(f"chunks_{regulation.code}", count)
except Exception as e:
logger.error(f" -> FAILED: {e}")
self.stats["by_regulation"][regulation.code] = 0
self.stats["chunks_processed"] = total_chunks
logger.info(f"\nTotal chunks in collection: {total_chunks}")
# Validate ingestion results
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
self.checkpoint_mgr.add_metric("regulations_count", len(REGULATIONS))
expected = EXPECTED_VALUES.get("ingestion", {})
self.checkpoint_mgr.validate(
"total_chunks",
expected=expected.get("total_chunks", 8000),
actual=total_chunks,
min_value=expected.get("min_chunks", 7000)
)
reg_expected = expected.get("regulations", {})
for reg_code, reg_exp in reg_expected.items():
actual = self.stats["by_regulation"].get(reg_code, 0)
self.checkpoint_mgr.validate(
f"chunks_{reg_code}",
expected=reg_exp.get("expected", 0),
actual=actual,
min_value=reg_exp.get("min", 0)
)
self.checkpoint_mgr.complete_checkpoint(success=True)
return total_chunks
except Exception as e:
if self.checkpoint_mgr:
self.checkpoint_mgr.fail_checkpoint(str(e))
raise
finally:
await ingestion.close()
async def run_extraction_phase(self) -> int:
"""Phase 2: Extract checkpoints from chunks."""
logger.info("\n" + "=" * 60)
logger.info("PHASE 2: CHECKPOINT EXTRACTION")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.start_checkpoint("extraction", "Checkpoint Extraction")
try:
offset = None
total_checkpoints = 0
while True:
result = self.qdrant.scroll(
collection_name=LEGAL_CORPUS_COLLECTION,
limit=100,
offset=offset,
with_payload=True,
with_vectors=False
)
points, next_offset = result
if not points:
break
for point in points:
payload = point.payload
text = payload.get("text", "")
cps = extract_checkpoints_from_chunk(text, payload)
self.checkpoints.extend(cps)
total_checkpoints += len(cps)
logger.info(f"Processed {len(points)} chunks, extracted {total_checkpoints} checkpoints so far...")
if next_offset is None:
break
offset = next_offset
self.stats["checkpoints_extracted"] = len(self.checkpoints)
logger.info(f"\nTotal checkpoints extracted: {len(self.checkpoints)}")
by_reg = {}
for cp in self.checkpoints:
by_reg[cp.regulation_code] = by_reg.get(cp.regulation_code, 0) + 1
for reg, count in sorted(by_reg.items()):
logger.info(f" {reg}: {count} checkpoints")
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_checkpoints", len(self.checkpoints))
self.checkpoint_mgr.add_metric("checkpoints_by_regulation", by_reg)
expected = EXPECTED_VALUES.get("extraction", {})
self.checkpoint_mgr.validate(
"total_checkpoints",
expected=expected.get("total_checkpoints", 3500),
actual=len(self.checkpoints),
min_value=expected.get("min_checkpoints", 3000)
)
self.checkpoint_mgr.complete_checkpoint(success=True)
return len(self.checkpoints)
except Exception as e:
if self.checkpoint_mgr:
self.checkpoint_mgr.fail_checkpoint(str(e))
raise
async def run_control_generation_phase(self) -> int:
"""Phase 3: Generate controls from checkpoints."""
logger.info("\n" + "=" * 60)
logger.info("PHASE 3: CONTROL GENERATION")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.start_checkpoint("controls", "Control Generation")
try:
# Group checkpoints by regulation
by_regulation: Dict[str, List[Checkpoint]] = {}
for cp in self.checkpoints:
reg = cp.regulation_code
if reg not in by_regulation:
by_regulation[reg] = []
by_regulation[reg].append(cp)
# Generate controls per regulation (group every 3-5 checkpoints)
for regulation, checkpoints in by_regulation.items():
logger.info(f"Generating controls for {regulation} ({len(checkpoints)} checkpoints)...")
batch_size = 4
for i in range(0, len(checkpoints), batch_size):
batch = checkpoints[i:i + batch_size]
control = generate_control_for_checkpoints(batch, self.stats.get("by_domain", {}))
if control:
self.controls.append(control)
self.stats["by_domain"][control.domain] = self.stats["by_domain"].get(control.domain, 0) + 1
self.stats["controls_created"] = len(self.controls)
logger.info(f"\nTotal controls created: {len(self.controls)}")
for domain, count in sorted(self.stats["by_domain"].items()):
logger.info(f" {domain}: {count} controls")
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_controls", len(self.controls))
self.checkpoint_mgr.add_metric("controls_by_domain", dict(self.stats["by_domain"]))
expected = EXPECTED_VALUES.get("controls", {})
self.checkpoint_mgr.validate(
"total_controls",
expected=expected.get("total_controls", 900),
actual=len(self.controls),
min_value=expected.get("min_controls", 800)
)
self.checkpoint_mgr.complete_checkpoint(success=True)
return len(self.controls)
except Exception as e:
if self.checkpoint_mgr:
self.checkpoint_mgr.fail_checkpoint(str(e))
raise
async def run_measure_generation_phase(self) -> int:
"""Phase 4: Generate measures for controls."""
logger.info("\n" + "=" * 60)
logger.info("PHASE 4: MEASURE GENERATION")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.start_checkpoint("measures", "Measure Generation")
try:
for control in self.controls:
measure = generate_measure_for_control(control)
self.measures.append(measure)
self.stats["measures_defined"] = len(self.measures)
logger.info(f"\nTotal measures defined: {len(self.measures)}")
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_measures", len(self.measures))
expected = EXPECTED_VALUES.get("measures", {})
self.checkpoint_mgr.validate(
"total_measures",
expected=expected.get("total_measures", 900),
actual=len(self.measures),
min_value=expected.get("min_measures", 800)
)
self.checkpoint_mgr.complete_checkpoint(success=True)
return len(self.measures)
except Exception as e:
if self.checkpoint_mgr:
self.checkpoint_mgr.fail_checkpoint(str(e))
raise
def save_results(self, output_dir: str = "/tmp/compliance_output"):
"""Save results to JSON files."""
logger.info("\n" + "=" * 60)
logger.info("SAVING RESULTS")
logger.info("=" * 60)
os.makedirs(output_dir, exist_ok=True)
checkpoints_file = os.path.join(output_dir, "checkpoints.json")
with open(checkpoints_file, "w") as f:
json.dump([asdict(cp) for cp in self.checkpoints], f, indent=2, ensure_ascii=False)
logger.info(f"Saved {len(self.checkpoints)} checkpoints to {checkpoints_file}")
controls_file = os.path.join(output_dir, "controls.json")
with open(controls_file, "w") as f:
json.dump([asdict(c) for c in self.controls], f, indent=2, ensure_ascii=False)
logger.info(f"Saved {len(self.controls)} controls to {controls_file}")
measures_file = os.path.join(output_dir, "measures.json")
with open(measures_file, "w") as f:
json.dump([asdict(m) for m in self.measures], f, indent=2, ensure_ascii=False)
logger.info(f"Saved {len(self.measures)} measures to {measures_file}")
stats_file = os.path.join(output_dir, "statistics.json")
self.stats["generated_at"] = datetime.now().isoformat()
with open(stats_file, "w") as f:
json.dump(self.stats, f, indent=2, ensure_ascii=False)
logger.info(f"Saved statistics to {stats_file}")
async def run_full_pipeline(self, force_reindex: bool = False, skip_ingestion: bool = False):
"""Run the complete pipeline.
Args:
force_reindex: If True, re-ingest all documents even if they exist
skip_ingestion: If True, skip ingestion phase entirely (use existing chunks)
"""
start_time = time.time()
logger.info("=" * 60)
logger.info("FULL COMPLIANCE PIPELINE (INCREMENTAL)")
logger.info(f"Started at: {datetime.now().isoformat()}")
logger.info(f"Force reindex: {force_reindex}")
logger.info(f"Skip ingestion: {skip_ingestion}")
if self.checkpoint_mgr:
logger.info(f"Pipeline ID: {self.checkpoint_mgr.pipeline_id}")
logger.info("=" * 60)
try:
if skip_ingestion:
logger.info("Skipping ingestion phase as requested...")
try:
collection_info = self.qdrant.get_collection(LEGAL_CORPUS_COLLECTION)
self.stats["chunks_processed"] = collection_info.points_count
except Exception:
self.stats["chunks_processed"] = 0
else:
await self.run_ingestion_phase(force_reindex=force_reindex)
await self.run_extraction_phase()
await self.run_control_generation_phase()
await self.run_measure_generation_phase()
self.save_results()
elapsed = time.time() - start_time
logger.info("\n" + "=" * 60)
logger.info("PIPELINE COMPLETE")
logger.info("=" * 60)
logger.info(f"Duration: {elapsed:.1f} seconds")
logger.info(f"Chunks processed: {self.stats['chunks_processed']}")
logger.info(f"Checkpoints extracted: {self.stats['checkpoints_extracted']}")
logger.info(f"Controls created: {self.stats['controls_created']}")
logger.info(f"Measures defined: {self.stats['measures_defined']}")
logger.info(f"\nResults saved to: /tmp/compliance_output/")
logger.info("Checkpoint status: /tmp/pipeline_checkpoints.json")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.complete_pipeline({
"duration_seconds": elapsed,
"chunks_processed": self.stats['chunks_processed'],
"checkpoints_extracted": self.stats['checkpoints_extracted'],
"controls_created": self.stats['controls_created'],
"measures_defined": self.stats['measures_defined'],
"by_regulation": self.stats['by_regulation'],
"by_domain": self.stats['by_domain'],
})
except Exception as e:
logger.error(f"Pipeline failed: {e}")
if self.checkpoint_mgr:
self.checkpoint_mgr.state.status = "failed"
self.checkpoint_mgr._save()
raise
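# Illustrative programmatic use (sketch; only works inside the klausur-service
# container, since the module-level imports above need Qdrant and the legal
# corpus modules):
#
#   import asyncio
#   pipeline = CompliancePipeline()
#   asyncio.run(pipeline.run_full_pipeline(skip_ingestion=True))
#   # artefacts land in /tmp/compliance_output/ via save_results()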
@@ -0,0 +1,38 @@
"""
RBAC/ABAC Policy System for Klausur-Service (barrel re-export)
This module was split into:
- rbac_types.py (Enums, data structures)
- rbac_permissions.py (Permission matrix)
- rbac_engine.py (PolicyEngine, default policies, API guards)
All public symbols are re-exported here for backwards compatibility.
"""
# Types and enums
from .rbac_types import ( # noqa: F401
Role,
Action,
ResourceType,
ZKVisibilityMode,
EHVisibilityMode,
VerfahrenType,
PolicySet,
RoleAssignment,
KeyShare,
Tenant,
Namespace,
ExamPackage,
)
# Permission matrix
from .rbac_permissions import DEFAULT_PERMISSIONS # noqa: F401
# Engine, policies, guards
from .rbac_engine import ( # noqa: F401
PolicyEngine,
create_default_policy_sets,
get_policy_engine,
require_permission,
require_role,
)
@@ -0,0 +1,498 @@
"""
RBAC Policy Engine
Core engine for RBAC/ABAC permission checks,
role assignments, key shares, and default policies.
Extracted from rbac.py for file-size compliance.
"""
from typing import Optional, List, Dict, Set
from datetime import datetime, timezone
import uuid
from functools import wraps
from fastapi import HTTPException, Request
from .rbac_types import (
    Role,
    Action,
    ResourceType,
    ZKVisibilityMode,
    EHVisibilityMode,
    PolicySet,
    RoleAssignment,
    KeyShare,
)
from .rbac_permissions import DEFAULT_PERMISSIONS
# =============================================
# POLICY ENGINE
# =============================================
class PolicyEngine:
"""
Engine fuer RBAC/ABAC Entscheidungen.
Prueft:
1. Basis-Rollenberechtigung (RBAC)
2. Policy-Einschraenkungen (ABAC)
3. Key Share Berechtigungen
"""
def __init__(self):
self.policy_sets: Dict[str, PolicySet] = {}
self.role_assignments: Dict[str, List[RoleAssignment]] = {} # user_id -> assignments
self.key_shares: Dict[str, List[KeyShare]] = {} # user_id -> shares
def register_policy_set(self, policy: PolicySet):
"""Registriere ein Policy Set."""
self.policy_sets[policy.id] = policy
def get_policy_for_context(
self,
bundesland: str,
jahr: int,
fach: Optional[str] = None,
verfahren: str = "abitur"
) -> Optional[PolicySet]:
"""Finde das passende Policy Set fuer einen Kontext."""
# Exakte Uebereinstimmung
for policy in self.policy_sets.values():
if (policy.bundesland == bundesland and
policy.jahr == jahr and
policy.verfahren == verfahren):
if policy.fach is None or policy.fach == fach:
return policy
# Fallback: Default Policy
for policy in self.policy_sets.values():
if policy.bundesland == "DEFAULT":
return policy
return None
def assign_role(
self,
user_id: str,
role: Role,
resource_type: ResourceType,
resource_id: str,
granted_by: str,
tenant_id: Optional[str] = None,
namespace_id: Optional[str] = None,
valid_to: Optional[datetime] = None
) -> RoleAssignment:
"""Weise einem User eine Rolle zu."""
assignment = RoleAssignment(
id=str(uuid.uuid4()),
user_id=user_id,
role=role,
resource_type=resource_type,
resource_id=resource_id,
tenant_id=tenant_id,
namespace_id=namespace_id,
granted_by=granted_by,
valid_to=valid_to
)
if user_id not in self.role_assignments:
self.role_assignments[user_id] = []
self.role_assignments[user_id].append(assignment)
return assignment
def revoke_role(self, assignment_id: str, revoked_by: str) -> bool:
"""Widerrufe eine Rollenzuweisung."""
for user_assignments in self.role_assignments.values():
for assignment in user_assignments:
if assignment.id == assignment_id:
assignment.revoked_at = datetime.now(timezone.utc)
return True
return False
def get_user_roles(
self,
user_id: str,
resource_type: Optional[ResourceType] = None,
resource_id: Optional[str] = None
) -> List[Role]:
"""Hole alle aktiven Rollen eines Users."""
assignments = self.role_assignments.get(user_id, [])
roles = []
for assignment in assignments:
if not assignment.is_active():
continue
if resource_type and assignment.resource_type != resource_type:
continue
if resource_id and assignment.resource_id != resource_id:
continue
roles.append(assignment.role)
return list(set(roles))
def create_key_share(
self,
user_id: str,
package_id: str,
permissions: Set[str],
granted_by: str,
scope: str = "full",
invite_token: Optional[str] = None
) -> KeyShare:
"""Erstelle einen Key Share."""
share = KeyShare(
id=str(uuid.uuid4()),
user_id=user_id,
package_id=package_id,
permissions=permissions,
scope=scope,
granted_by=granted_by,
invite_token=invite_token
)
if user_id not in self.key_shares:
self.key_shares[user_id] = []
self.key_shares[user_id].append(share)
return share
def accept_key_share(self, share_id: str, token: str) -> bool:
"""Akzeptiere einen Key Share via Invite Token."""
for user_shares in self.key_shares.values():
for share in user_shares:
if share.id == share_id and share.invite_token == token:
share.accepted_at = datetime.now(timezone.utc)
return True
return False
def revoke_key_share(self, share_id: str, revoked_by: str) -> bool:
"""Widerrufe einen Key Share."""
for user_shares in self.key_shares.values():
for share in user_shares:
if share.id == share_id:
share.revoked_at = datetime.now(timezone.utc)
share.revoked_by = revoked_by
return True
return False
def check_permission(
self,
user_id: str,
action: Action,
resource_type: ResourceType,
resource_id: str,
policy: Optional[PolicySet] = None,
package_id: Optional[str] = None
) -> bool:
"""
Pruefe ob ein User eine Aktion ausfuehren darf.
Prueft:
1. Basis-RBAC
2. Policy-Einschraenkungen
3. Key Share (falls package_id angegeben)
"""
# 1. Hole aktive Rollen
roles = self.get_user_roles(user_id, resource_type, resource_id)
if not roles:
return False
# 2. Pruefe Basis-RBAC
has_permission = False
for role in roles:
role_permissions = DEFAULT_PERMISSIONS.get(role, {})
resource_permissions = role_permissions.get(resource_type, set())
if action in resource_permissions:
has_permission = True
break
if not has_permission:
return False
# 3. Pruefe Policy-Einschraenkungen
if policy:
# ZK Visibility Mode
if Role.ZWEITKORREKTOR in roles:
if policy.zk_visibility_mode == ZKVisibilityMode.BLIND:
# Blind: ZK darf EK-Outputs nicht sehen
if resource_type in [ResourceType.EVALUATION, ResourceType.REPORT, ResourceType.GRADE_DECISION]:
if action == Action.READ:
# Pruefe ob es EK-Outputs sind (muesste ueber Metadaten geprueft werden)
pass # Implementierung abhaengig von Datenmodell
elif policy.zk_visibility_mode == ZKVisibilityMode.SEMI:
# Semi: ZK sieht Annotationen, aber keine Note
if resource_type == ResourceType.GRADE_DECISION and action == Action.READ:
return False
# 4. Pruefe Key Share (falls Package-basiert)
if package_id:
user_shares = self.key_shares.get(user_id, [])
has_key_share = any(
share.package_id == package_id and share.is_active()
for share in user_shares
)
if not has_key_share:
return False
return True
def get_allowed_actions(
self,
user_id: str,
resource_type: ResourceType,
resource_id: str,
policy: Optional[PolicySet] = None
) -> Set[Action]:
"""Hole alle erlaubten Aktionen fuer einen User auf einer Ressource."""
roles = self.get_user_roles(user_id, resource_type, resource_id)
allowed = set()
for role in roles:
role_permissions = DEFAULT_PERMISSIONS.get(role, {})
resource_permissions = role_permissions.get(resource_type, set())
allowed.update(resource_permissions)
# Policy-Einschraenkungen anwenden
if policy and Role.ZWEITKORREKTOR in roles:
if policy.zk_visibility_mode == ZKVisibilityMode.BLIND:
# Entferne READ fuer bestimmte Ressourcen
pass # Detailimplementierung
return allowed
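# Illustrative end-to-end check against the engine above (sketch; IDs and
# user names are invented):
#
#   engine = PolicyEngine()
#   engine.assign_role(
#       user_id="u-ek-1", role=Role.ERSTKORREKTOR,
#       resource_type=ResourceType.EXAM_PACKAGE, resource_id="pkg-42",
#       granted_by="u-admin",
#   )
#   engine.create_key_share(
#       user_id="u-ek-1", package_id="pkg-42",
#       permissions={"read_original"}, granted_by="u-admin",
#   )
#   engine.check_permission(
#       user_id="u-ek-1", action=Action.READ,
#       resource_type=ResourceType.EXAM_PACKAGE, resource_id="pkg-42",
#       package_id="pkg-42",
#   )
#   # -> True: READ is in DEFAULT_PERMISSIONS for ERSTKORREKTOR and the
#   #    key share for pkg-42 is active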
# =============================================
# DEFAULT POLICY SETS (alle Bundeslaender)
# =============================================
def create_default_policy_sets() -> List[PolicySet]:
"""
Erstelle Default Policy Sets fuer alle Bundeslaender.
Diese koennen spaeter pro Land verfeinert werden.
"""
bundeslaender = [
"baden-wuerttemberg", "bayern", "berlin", "brandenburg",
"bremen", "hamburg", "hessen", "mecklenburg-vorpommern",
"niedersachsen", "nordrhein-westfalen", "rheinland-pfalz",
"saarland", "sachsen", "sachsen-anhalt", "schleswig-holstein",
"thueringen"
]
policies = []
# Default Policy (Fallback)
policies.append(PolicySet(
id="DEFAULT-2025",
bundesland="DEFAULT",
jahr=2025,
fach=None,
verfahren="abitur",
zk_visibility_mode=ZKVisibilityMode.FULL,
        eh_visibility_mode=EHVisibilityMode.SHARED,
allow_teacher_uploaded_eh=True,
allow_land_uploaded_eh=True,
require_rights_confirmation_on_upload=True,
third_correction_threshold=4,
final_signoff_role="fachvorsitz"
))
# Niedersachsen (Beispiel mit spezifischen Anpassungen)
policies.append(PolicySet(
id="NI-2025-ABITUR",
bundesland="niedersachsen",
jahr=2025,
fach=None,
verfahren="abitur",
zk_visibility_mode=ZKVisibilityMode.FULL, # In NI sieht ZK alles
allow_teacher_uploaded_eh=True,
allow_land_uploaded_eh=True,
require_rights_confirmation_on_upload=True,
third_correction_threshold=4,
final_signoff_role="fachvorsitz",
export_template_id="niedersachsen-abitur"
))
# Bayern (Beispiel mit SEMI visibility)
policies.append(PolicySet(
id="BY-2025-ABITUR",
bundesland="bayern",
jahr=2025,
fach=None,
verfahren="abitur",
zk_visibility_mode=ZKVisibilityMode.SEMI, # ZK sieht Annotationen, nicht Note
allow_teacher_uploaded_eh=True,
allow_land_uploaded_eh=True,
require_rights_confirmation_on_upload=True,
third_correction_threshold=4,
final_signoff_role="fachvorsitz",
export_template_id="bayern-abitur"
))
# NRW (Beispiel)
policies.append(PolicySet(
id="NW-2025-ABITUR",
bundesland="nordrhein-westfalen",
jahr=2025,
fach=None,
verfahren="abitur",
zk_visibility_mode=ZKVisibilityMode.FULL,
allow_teacher_uploaded_eh=True,
allow_land_uploaded_eh=True,
require_rights_confirmation_on_upload=True,
third_correction_threshold=4,
final_signoff_role="fachvorsitz",
export_template_id="nrw-abitur"
))
    # Generiere Basis-Policies fuer alle anderen Bundeslaender.
    # Offizielle Laenderkuerzel statt bl[:2], da sonst Policy-IDs kollidieren
    # (z.B. sachsen / sachsen-anhalt / saarland -> "SA").
    land_codes = {
        "baden-wuerttemberg": "BW", "berlin": "BE", "brandenburg": "BB",
        "bremen": "HB", "hamburg": "HH", "hessen": "HE",
        "mecklenburg-vorpommern": "MV", "rheinland-pfalz": "RP",
        "saarland": "SL", "sachsen": "SN", "sachsen-anhalt": "ST",
        "schleswig-holstein": "SH", "thueringen": "TH",
    }
    for bl in bundeslaender:
        if bl not in ["niedersachsen", "bayern", "nordrhein-westfalen"]:
            policies.append(PolicySet(
                id=f"{land_codes[bl]}-2025-ABITUR",
                bundesland=bl,
                jahr=2025,
                fach=None,
                verfahren="abitur",
                zk_visibility_mode=ZKVisibilityMode.FULL,
                allow_teacher_uploaded_eh=True,
                allow_land_uploaded_eh=True,
                require_rights_confirmation_on_upload=True,
                third_correction_threshold=4,
                final_signoff_role="fachvorsitz"
            ))
return policies
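# Illustrative lookup against these defaults (sketch; results follow from the
# policies defined above):
#
#   engine = PolicyEngine()
#   for p in create_default_policy_sets():
#       engine.register_policy_set(p)
#   engine.get_policy_for_context("bayern", 2025).zk_visibility_mode
#   # -> ZKVisibilityMode.SEMI (BY-2025-ABITUR)
#   engine.get_policy_for_context("unbekannt", 2025).bundesland
#   # -> "DEFAULT" (fallback policy)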
# =============================================
# GLOBAL POLICY ENGINE INSTANCE
# =============================================
# Singleton Policy Engine
_policy_engine: Optional[PolicyEngine] = None
def get_policy_engine() -> PolicyEngine:
"""Hole die globale Policy Engine Instanz."""
global _policy_engine
if _policy_engine is None:
_policy_engine = PolicyEngine()
# Registriere Default Policies
for policy in create_default_policy_sets():
_policy_engine.register_policy_set(policy)
return _policy_engine
# =============================================
# API GUARDS (Decorators fuer FastAPI)
# =============================================
def require_permission(
action: Action,
resource_type: ResourceType,
resource_id_param: str = "resource_id"
):
"""
Decorator fuer FastAPI Endpoints.
Prueft ob der aktuelle User die angegebene Berechtigung hat.
Usage:
@app.get("/api/v1/packages/{package_id}")
@require_permission(Action.READ, ResourceType.EXAM_PACKAGE, "package_id")
async def get_package(package_id: str, request: Request):
...
"""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
request = kwargs.get('request')
if not request:
for arg in args:
if isinstance(arg, Request):
request = arg
break
if not request:
raise HTTPException(status_code=500, detail="Request not found")
# User aus Token holen
user = getattr(request.state, 'user', None)
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user.get('user_id')
resource_id = kwargs.get(resource_id_param)
# Policy Engine pruefen
engine = get_policy_engine()
# Optional: Policy aus Kontext laden
policy = None
bundesland = user.get('bundesland')
if bundesland:
policy = engine.get_policy_for_context(bundesland, 2025)
if not engine.check_permission(
user_id=user_id,
action=action,
resource_type=resource_type,
resource_id=resource_id,
policy=policy
):
raise HTTPException(
status_code=403,
detail=f"Permission denied: {action.value} on {resource_type.value}"
)
return await func(*args, **kwargs)
return wrapper
return decorator
def require_role(role: Role):
"""
Decorator der prueft ob User eine bestimmte Rolle hat.
Usage:
@app.post("/api/v1/eh/publish")
@require_role(Role.LAND_ADMIN)
async def publish_eh(request: Request):
...
"""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
request = kwargs.get('request')
if not request:
for arg in args:
if isinstance(arg, Request):
request = arg
break
if not request:
raise HTTPException(status_code=500, detail="Request not found")
user = getattr(request.state, 'user', None)
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user.get('user_id')
engine = get_policy_engine()
user_roles = engine.get_user_roles(user_id)
if role not in user_roles:
raise HTTPException(
status_code=403,
detail=f"Role required: {role.value}"
)
return await func(*args, **kwargs)
return wrapper
return decorator
@@ -0,0 +1,221 @@
"""
RBAC Permission Matrix
Default role-to-resource permission mappings for
Klausur-Korrektur and Zeugnis workflows.
Extracted from rbac.py for file-size compliance.
"""
from typing import Dict, Set
from .rbac_types import Role, Action, ResourceType
# =============================================
# RBAC PERMISSION MATRIX
# =============================================
# Standard-Berechtigungsmatrix (kann durch Policies ueberschrieben werden)
DEFAULT_PERMISSIONS: Dict[Role, Dict[ResourceType, Set[Action]]] = {
# Erstkorrektor
Role.ERSTKORREKTOR: {
ResourceType.EXAM_PACKAGE: {Action.READ, Action.UPDATE, Action.SHARE_KEY, Action.LOCK},
ResourceType.STUDENT_WORK: {Action.READ, Action.UPDATE},
ResourceType.EH_DOCUMENT: {Action.READ, Action.UPLOAD, Action.UPDATE},
ResourceType.RUBRIC: {Action.READ, Action.UPDATE},
ResourceType.ANNOTATION: {Action.CREATE, Action.READ, Action.UPDATE, Action.DELETE},
ResourceType.EVALUATION: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.REPORT: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.GRADE_DECISION: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.EXPORT: {Action.CREATE, Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Zweitkorrektor (Standard: FULL visibility)
Role.ZWEITKORREKTOR: {
ResourceType.EXAM_PACKAGE: {Action.READ},
ResourceType.STUDENT_WORK: {Action.READ, Action.UPDATE},
ResourceType.EH_DOCUMENT: {Action.READ},
ResourceType.RUBRIC: {Action.READ},
ResourceType.ANNOTATION: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.EVALUATION: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.REPORT: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.GRADE_DECISION: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.EXPORT: {Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Drittkorrektor
Role.DRITTKORREKTOR: {
ResourceType.EXAM_PACKAGE: {Action.READ},
ResourceType.STUDENT_WORK: {Action.READ, Action.UPDATE},
ResourceType.EH_DOCUMENT: {Action.READ},
ResourceType.RUBRIC: {Action.READ},
ResourceType.ANNOTATION: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.EVALUATION: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.REPORT: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.GRADE_DECISION: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Fachvorsitz
Role.FACHVORSITZ: {
ResourceType.TENANT: {Action.READ},
ResourceType.NAMESPACE: {Action.READ, Action.UPDATE},
ResourceType.EXAM_PACKAGE: {Action.READ, Action.UPDATE, Action.LOCK, Action.UNLOCK, Action.SIGN_OFF},
ResourceType.STUDENT_WORK: {Action.READ, Action.UPDATE},
ResourceType.EH_DOCUMENT: {Action.READ, Action.UPLOAD, Action.UPDATE},
ResourceType.RUBRIC: {Action.READ, Action.UPDATE},
ResourceType.ANNOTATION: {Action.READ, Action.UPDATE},
ResourceType.EVALUATION: {Action.READ, Action.UPDATE},
ResourceType.REPORT: {Action.READ, Action.UPDATE},
ResourceType.GRADE_DECISION: {Action.READ, Action.UPDATE, Action.SIGN_OFF},
ResourceType.EXPORT: {Action.CREATE, Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Pruefungsvorsitz
Role.PRUEFUNGSVORSITZ: {
ResourceType.TENANT: {Action.READ},
ResourceType.NAMESPACE: {Action.READ, Action.CREATE},
ResourceType.EXAM_PACKAGE: {Action.READ, Action.SIGN_OFF},
ResourceType.STUDENT_WORK: {Action.READ},
ResourceType.EH_DOCUMENT: {Action.READ},
ResourceType.GRADE_DECISION: {Action.READ, Action.SIGN_OFF},
ResourceType.EXPORT: {Action.CREATE, Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Schul-Admin
Role.SCHUL_ADMIN: {
ResourceType.TENANT: {Action.READ, Action.UPDATE},
ResourceType.NAMESPACE: {Action.CREATE, Action.READ, Action.UPDATE, Action.DELETE},
ResourceType.EXAM_PACKAGE: {Action.CREATE, Action.READ, Action.DELETE, Action.ASSIGN_ROLE},
ResourceType.EH_DOCUMENT: {Action.READ, Action.UPLOAD, Action.DELETE},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Land-Admin (Behoerde)
Role.LAND_ADMIN: {
ResourceType.TENANT: {Action.READ},
ResourceType.EH_DOCUMENT: {Action.READ, Action.UPLOAD, Action.UPDATE, Action.DELETE, Action.PUBLISH_OFFICIAL},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Auditor
Role.AUDITOR: {
ResourceType.AUDIT_LOG: {Action.READ},
ResourceType.EXAM_PACKAGE: {Action.READ}, # Nur Metadaten
# Kein Zugriff auf Inhalte!
},
# Operator
Role.OPERATOR: {
ResourceType.TENANT: {Action.READ},
ResourceType.NAMESPACE: {Action.READ},
ResourceType.EXAM_PACKAGE: {Action.READ}, # Nur Metadaten
ResourceType.AUDIT_LOG: {Action.READ},
# Break-glass separat gehandhabt
},
# Teacher Assistant
Role.TEACHER_ASSISTANT: {
ResourceType.STUDENT_WORK: {Action.READ},
ResourceType.ANNOTATION: {Action.CREATE, Action.READ}, # Nur bestimmte Typen
ResourceType.EH_DOCUMENT: {Action.READ},
},
# Exam Author (nur Vorabi)
Role.EXAM_AUTHOR: {
ResourceType.EH_DOCUMENT: {Action.CREATE, Action.READ, Action.UPDATE, Action.DELETE},
ResourceType.RUBRIC: {Action.CREATE, Action.READ, Action.UPDATE, Action.DELETE},
},
# =============================================
# ZEUGNIS-WORKFLOW ROLLEN
# =============================================
# Klassenlehrer - Erstellt Zeugnisse, Kopfnoten, Bemerkungen
Role.KLASSENLEHRER: {
ResourceType.NAMESPACE: {Action.READ},
ResourceType.ZEUGNIS: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.ZEUGNIS_ENTWURF: {Action.CREATE, Action.READ, Action.UPDATE, Action.DELETE},
ResourceType.ZEUGNIS_VORLAGE: {Action.READ},
ResourceType.SCHUELER_DATEN: {Action.READ, Action.UPDATE},
ResourceType.FACHNOTE: {Action.READ}, # Liest Fachnoten der Fachlehrer
ResourceType.KOPFNOTE: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.FEHLZEITEN: {Action.READ, Action.UPDATE},
ResourceType.BEMERKUNG: {Action.CREATE, Action.READ, Action.UPDATE, Action.DELETE},
ResourceType.VERSETZUNG: {Action.READ},
ResourceType.EXPORT: {Action.CREATE, Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Fachlehrer - Traegt Fachnoten ein
Role.FACHLEHRER: {
ResourceType.NAMESPACE: {Action.READ},
ResourceType.SCHUELER_DATEN: {Action.READ}, # Nur eigene Schueler
ResourceType.FACHNOTE: {Action.CREATE, Action.READ, Action.UPDATE}, # Nur eigenes Fach
ResourceType.BEMERKUNG: {Action.CREATE, Action.READ}, # Fachbezogene Bemerkungen
ResourceType.AUDIT_LOG: {Action.READ},
},
# Zeugnisbeauftragter - Qualitaetskontrolle
Role.ZEUGNISBEAUFTRAGTER: {
ResourceType.NAMESPACE: {Action.READ, Action.UPDATE},
ResourceType.ZEUGNIS: {Action.READ, Action.UPDATE},
ResourceType.ZEUGNIS_ENTWURF: {Action.READ, Action.UPDATE},
ResourceType.ZEUGNIS_VORLAGE: {Action.READ, Action.UPDATE, Action.UPLOAD},
ResourceType.SCHUELER_DATEN: {Action.READ},
ResourceType.FACHNOTE: {Action.READ},
ResourceType.KOPFNOTE: {Action.READ, Action.UPDATE},
ResourceType.FEHLZEITEN: {Action.READ},
ResourceType.BEMERKUNG: {Action.READ, Action.UPDATE},
ResourceType.VERSETZUNG: {Action.READ},
ResourceType.EXPORT: {Action.CREATE, Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Sekretariat - Druck, Versand, Archivierung
Role.SEKRETARIAT: {
ResourceType.ZEUGNIS: {Action.READ, Action.DOWNLOAD},
ResourceType.ZEUGNIS_VORLAGE: {Action.READ},
ResourceType.SCHUELER_DATEN: {Action.READ}, # Fuer Adressdaten
ResourceType.EXPORT: {Action.CREATE, Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Schulleitung - Finale Zeugnis-Freigabe
Role.SCHULLEITUNG: {
ResourceType.TENANT: {Action.READ},
ResourceType.NAMESPACE: {Action.READ, Action.CREATE},
ResourceType.ZEUGNIS: {Action.READ, Action.SIGN_OFF, Action.LOCK},
ResourceType.ZEUGNIS_ENTWURF: {Action.READ, Action.UPDATE},
ResourceType.ZEUGNIS_VORLAGE: {Action.READ, Action.UPDATE},
ResourceType.SCHUELER_DATEN: {Action.READ},
ResourceType.FACHNOTE: {Action.READ},
ResourceType.KOPFNOTE: {Action.READ, Action.UPDATE},
ResourceType.FEHLZEITEN: {Action.READ},
ResourceType.BEMERKUNG: {Action.READ, Action.UPDATE},
ResourceType.KONFERENZ_BESCHLUSS: {Action.CREATE, Action.READ, Action.UPDATE, Action.SIGN_OFF},
ResourceType.VERSETZUNG: {Action.CREATE, Action.READ, Action.UPDATE, Action.SIGN_OFF},
ResourceType.EXPORT: {Action.CREATE, Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Stufenleitung - Stufenkoordination (z.B. Oberstufe)
Role.STUFENLEITUNG: {
ResourceType.NAMESPACE: {Action.READ, Action.UPDATE},
ResourceType.ZEUGNIS: {Action.READ, Action.UPDATE},
ResourceType.ZEUGNIS_ENTWURF: {Action.READ, Action.UPDATE},
ResourceType.SCHUELER_DATEN: {Action.READ},
ResourceType.FACHNOTE: {Action.READ},
ResourceType.KOPFNOTE: {Action.READ},
ResourceType.FEHLZEITEN: {Action.READ},
ResourceType.BEMERKUNG: {Action.READ, Action.UPDATE},
ResourceType.KONFERENZ_BESCHLUSS: {Action.READ},
ResourceType.VERSETZUNG: {Action.READ, Action.UPDATE},
ResourceType.EXPORT: {Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
}
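# Illustrative reads of the matrix above (sketch, not a test from the repo):
#
#   DEFAULT_PERMISSIONS[Role.FACHLEHRER][ResourceType.FACHNOTE]
#   # -> {Action.CREATE, Action.READ, Action.UPDATE}
#   Action.SIGN_OFF in DEFAULT_PERMISSIONS[Role.SCHULLEITUNG][ResourceType.ZEUGNIS]
#   # -> True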
@@ -0,0 +1,438 @@
"""
RBAC/ABAC Type Definitions
Enums, data structures, and models for the policy system.
Extracted from rbac.py for file-size compliance.
"""
import json
from enum import Enum
from dataclasses import dataclass, field, asdict
from typing import Optional, List, Dict, Set, Any
from datetime import datetime, timezone
import uuid
# =============================================
# ENUMS: Roles, Actions, Resources
# =============================================
class Role(str, Enum):
"""Fachliche Rollen in Korrektur- und Zeugniskette."""
# === Klausur-Korrekturkette ===
ERSTKORREKTOR = "erstkorrektor" # EK
ZWEITKORREKTOR = "zweitkorrektor" # ZK
DRITTKORREKTOR = "drittkorrektor" # DK
# === Zeugnis-Workflow ===
KLASSENLEHRER = "klassenlehrer" # KL - Erstellt Zeugnis, Kopfnoten, Bemerkungen
FACHLEHRER = "fachlehrer" # FL - Traegt Fachnoten ein
ZEUGNISBEAUFTRAGTER = "zeugnisbeauftragter" # ZB - Qualitaetskontrolle
SEKRETARIAT = "sekretariat" # SEK - Druck, Versand, Archivierung
# === Leitung (Klausur + Zeugnis) ===
FACHVORSITZ = "fachvorsitz" # FVL - Fachpruefungsleitung
PRUEFUNGSVORSITZ = "pruefungsvorsitz" # PV - Schulleitung / Pruefungsvorsitz
SCHULLEITUNG = "schulleitung" # SL - Finale Zeugnis-Freigabe
STUFENLEITUNG = "stufenleitung" # STL - Stufenkoordination
# === Administration ===
SCHUL_ADMIN = "schul_admin" # SA
LAND_ADMIN = "land_admin" # LA - Behoerde
# === Spezial ===
AUDITOR = "auditor" # DSB/Auditor
OPERATOR = "operator" # OPS - Support
TEACHER_ASSISTANT = "teacher_assistant" # TA - Referendar
EXAM_AUTHOR = "exam_author" # EA - nur Vorabi
class Action(str, Enum):
"""Moegliche Operationen auf Ressourcen."""
CREATE = "create"
READ = "read"
UPDATE = "update"
DELETE = "delete"
ASSIGN_ROLE = "assign_role"
INVITE_USER = "invite_user"
REMOVE_USER = "remove_user"
UPLOAD = "upload"
DOWNLOAD = "download"
LOCK = "lock" # Finalisieren
UNLOCK = "unlock" # Nur mit Sonderrecht
SIGN_OFF = "sign_off" # Freigabe
SHARE_KEY = "share_key" # Key Share erzeugen
VIEW_PII = "view_pii" # Falls PII vorhanden
BREAK_GLASS = "break_glass" # Notfallzugriff
PUBLISH_OFFICIAL = "publish_official" # Amtliche EH verteilen
class ResourceType(str, Enum):
"""Ressourcentypen im System."""
TENANT = "tenant"
NAMESPACE = "namespace"
# === Klausur-Korrektur ===
EXAM_PACKAGE = "exam_package"
STUDENT_WORK = "student_work"
EH_DOCUMENT = "eh_document"
RUBRIC = "rubric" # Punkteraster
ANNOTATION = "annotation"
EVALUATION = "evaluation" # Kriterien/Punkte
REPORT = "report" # Gutachten
GRADE_DECISION = "grade_decision"
# === Zeugnisgenerator ===
ZEUGNIS = "zeugnis" # Zeugnisdokument
ZEUGNIS_VORLAGE = "zeugnis_vorlage" # Zeugnisvorlage/Template
ZEUGNIS_ENTWURF = "zeugnis_entwurf" # Zeugnisentwurf (vor Freigabe)
SCHUELER_DATEN = "schueler_daten" # Schueler-Stammdaten, Noten
FACHNOTE = "fachnote" # Einzelne Fachnote
KOPFNOTE = "kopfnote" # Arbeits-/Sozialverhalten
FEHLZEITEN = "fehlzeiten" # Fehlzeiten
BEMERKUNG = "bemerkung" # Zeugnisbemerkungen
KONFERENZ_BESCHLUSS = "konferenz_beschluss" # Konferenzergebnis
VERSETZUNG = "versetzung" # Versetzungsentscheidung
# === Allgemein ===
DOCUMENT = "document" # Generischer Dokumenttyp (EH, Vorlagen, etc.)
TEMPLATE = "template" # Generische Vorlagen
EXPORT = "export"
AUDIT_LOG = "audit_log"
KEY_MATERIAL = "key_material"
class ZKVisibilityMode(str, Enum):
"""Sichtbarkeitsmodus fuer Zweitkorrektoren."""
BLIND = "blind" # ZK sieht keine EK-Note/Gutachten
SEMI = "semi" # ZK sieht Annotationen, aber keine Note
FULL = "full" # ZK sieht alles
class EHVisibilityMode(str, Enum):
"""Sichtbarkeitsmodus fuer Erwartungshorizonte."""
BLIND = "blind" # ZK sieht EH nicht (selten)
SHARED = "shared" # ZK sieht EH (Standard)
class VerfahrenType(str, Enum):
"""Verfahrenstypen fuer Klausuren und Zeugnisse."""
# === Klausur/Pruefungsverfahren ===
ABITUR = "abitur"
VORABITUR = "vorabitur"
KLAUSUR = "klausur"
NACHPRUEFUNG = "nachpruefung"
# === Zeugnisverfahren ===
HALBJAHRESZEUGNIS = "halbjahreszeugnis"
JAHRESZEUGNIS = "jahreszeugnis"
ABSCHLUSSZEUGNIS = "abschlusszeugnis"
ABGANGSZEUGNIS = "abgangszeugnis"
@classmethod
def is_exam_type(cls, verfahren: str) -> bool:
"""Pruefe ob Verfahren ein Pruefungstyp ist."""
exam_types = {cls.ABITUR, cls.VORABITUR, cls.KLAUSUR, cls.NACHPRUEFUNG}
try:
return cls(verfahren) in exam_types
except ValueError:
return False
@classmethod
def is_certificate_type(cls, verfahren: str) -> bool:
"""Pruefe ob Verfahren ein Zeugnistyp ist."""
cert_types = {cls.HALBJAHRESZEUGNIS, cls.JAHRESZEUGNIS, cls.ABSCHLUSSZEUGNIS, cls.ABGANGSZEUGNIS}
try:
return cls(verfahren) in cert_types
except ValueError:
return False
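# Illustrative classification via the helpers above (sketch):
#
#   VerfahrenType.is_exam_type("abitur")                 # -> True
#   VerfahrenType.is_certificate_type("jahreszeugnis")   # -> True
#   VerfahrenType.is_exam_type("irgendwas")              # -> False (ValueError branch)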
# =============================================
# DATA STRUCTURES
# =============================================
@dataclass
class PolicySet:
"""
Policy-Konfiguration pro Bundesland/Jahr/Fach.
Ermoeglicht bundesland-spezifische Unterschiede ohne
harte Codierung im Quellcode.
Unterstuetzte Verfahrenstypen:
- Pruefungen: abitur, vorabitur, klausur, nachpruefung
- Zeugnisse: halbjahreszeugnis, jahreszeugnis, abschlusszeugnis, abgangszeugnis
"""
id: str
bundesland: str
jahr: int
fach: Optional[str] # None = gilt fuer alle Faecher
verfahren: str # See VerfahrenType enum
# Sichtbarkeitsregeln (Klausur)
zk_visibility_mode: ZKVisibilityMode = ZKVisibilityMode.FULL
eh_visibility_mode: EHVisibilityMode = EHVisibilityMode.SHARED
# EH-Quellen (Klausur)
allow_teacher_uploaded_eh: bool = True
allow_land_uploaded_eh: bool = True
require_rights_confirmation_on_upload: bool = True
require_dual_control_for_official_eh_update: bool = False
# Korrekturregeln (Klausur)
third_correction_threshold: int = 4 # Notenpunkte Abweichung
final_signoff_role: str = "fachvorsitz"
# Zeugnisregeln (Zeugnis)
require_klassenlehrer_approval: bool = True
require_schulleitung_signoff: bool = True
allow_sekretariat_edit_after_approval: bool = False
konferenz_protokoll_required: bool = True
bemerkungen_require_review: bool = True
fehlzeiten_auto_import: bool = True
kopfnoten_enabled: bool = False
versetzung_auto_calculate: bool = True
# Export & Anzeige
quote_verbatim_allowed: bool = False # Amtliche Texte in UI
export_template_id: str = "default"
# Zusaetzliche Flags
flags: Dict[str, Any] = field(default_factory=dict)
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
def is_exam_policy(self) -> bool:
"""Pruefe ob diese Policy fuer Pruefungen ist."""
return VerfahrenType.is_exam_type(self.verfahren)
def is_certificate_policy(self) -> bool:
"""Pruefe ob diese Policy fuer Zeugnisse ist."""
return VerfahrenType.is_certificate_type(self.verfahren)
def to_dict(self):
d = asdict(self)
d['zk_visibility_mode'] = self.zk_visibility_mode.value
d['eh_visibility_mode'] = self.eh_visibility_mode.value
d['created_at'] = self.created_at.isoformat()
return d
@dataclass
class RoleAssignment:
"""
Zuweisung einer Rolle zu einem User fuer eine spezifische Ressource.
"""
id: str
user_id: str
role: Role
resource_type: ResourceType
resource_id: str
# Optionale Einschraenkungen
tenant_id: Optional[str] = None
namespace_id: Optional[str] = None
# Gueltigkeit
valid_from: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
valid_to: Optional[datetime] = None
# Metadaten
granted_by: str = ""
granted_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
revoked_at: Optional[datetime] = None
def is_active(self) -> bool:
now = datetime.now(timezone.utc)
if self.revoked_at:
return False
if self.valid_to and now > self.valid_to:
return False
return now >= self.valid_from
def to_dict(self):
return {
'id': self.id,
'user_id': self.user_id,
'role': self.role.value,
'resource_type': self.resource_type.value,
'resource_id': self.resource_id,
'tenant_id': self.tenant_id,
'namespace_id': self.namespace_id,
'valid_from': self.valid_from.isoformat(),
'valid_to': self.valid_to.isoformat() if self.valid_to else None,
'granted_by': self.granted_by,
'granted_at': self.granted_at.isoformat(),
'revoked_at': self.revoked_at.isoformat() if self.revoked_at else None,
'is_active': self.is_active()
}
@dataclass
class KeyShare:
"""
Berechtigung fuer einen User, auf verschluesselte Inhalte zuzugreifen.
Ein KeyShare ist KEIN Schluessel im Klartext, sondern eine
Berechtigung in Verbindung mit Role Assignment.
"""
id: str
user_id: str
package_id: str
# Berechtigungsumfang
permissions: Set[str] = field(default_factory=set)
# z.B. {"read_original", "read_eh", "read_ek_outputs", "write_annotations"}
# Optionale Einschraenkungen
scope: str = "full" # "full", "original_only", "eh_only", "outputs_only"
# Kette
granted_by: str = ""
granted_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
# Akzeptanz (fuer Invite-Flow)
invite_token: Optional[str] = None
accepted_at: Optional[datetime] = None
# Widerruf
revoked_at: Optional[datetime] = None
revoked_by: Optional[str] = None
def is_active(self) -> bool:
return self.revoked_at is None and (
self.invite_token is None or self.accepted_at is not None
)
def to_dict(self):
return {
'id': self.id,
'user_id': self.user_id,
'package_id': self.package_id,
'permissions': list(self.permissions),
'scope': self.scope,
'granted_by': self.granted_by,
'granted_at': self.granted_at.isoformat(),
'invite_token': self.invite_token,
'accepted_at': self.accepted_at.isoformat() if self.accepted_at else None,
'revoked_at': self.revoked_at.isoformat() if self.revoked_at else None,
'is_active': self.is_active()
}
@dataclass
class Tenant:
"""
Hoechste Isolationseinheit - typischerweise eine Schule.
"""
id: str
name: str
bundesland: str
tenant_type: str = "school" # "school", "pruefungszentrum", "behoerde"
# Verschluesselung
encryption_enabled: bool = True
# Metadaten
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
deleted_at: Optional[datetime] = None
def to_dict(self):
return {
'id': self.id,
'name': self.name,
'bundesland': self.bundesland,
'tenant_type': self.tenant_type,
'encryption_enabled': self.encryption_enabled,
'created_at': self.created_at.isoformat()
}
@dataclass
class Namespace:
"""
Arbeitsraum innerhalb eines Tenants.
z.B. "Abitur 2026 - Deutsch LK - Kurs 12a"
"""
id: str
tenant_id: str
name: str
# Kontext
jahr: int
fach: str
kurs: Optional[str] = None
pruefungsart: str = "abitur" # "abitur", "vorabitur"
# Policy
policy_set_id: Optional[str] = None
# Metadaten
created_by: str = ""
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
deleted_at: Optional[datetime] = None
def to_dict(self):
return {
'id': self.id,
'tenant_id': self.tenant_id,
'name': self.name,
'jahr': self.jahr,
'fach': self.fach,
'kurs': self.kurs,
'pruefungsart': self.pruefungsart,
'policy_set_id': self.policy_set_id,
'created_by': self.created_by,
'created_at': self.created_at.isoformat()
}
@dataclass
class ExamPackage:
"""
Pruefungspaket - kompletter Satz Arbeiten mit allen Artefakten.
"""
id: str
namespace_id: str
tenant_id: str
name: str
beschreibung: Optional[str] = None
# Workflow-Status
status: str = "draft" # "draft", "in_progress", "locked", "signed_off"
# Beteiligte (Rollen werden separat zugewiesen)
owner_id: str = "" # Typischerweise EK
# Verschluesselung
encryption_key_id: Optional[str] = None
# Timestamps
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
locked_at: Optional[datetime] = None
signed_off_at: Optional[datetime] = None
signed_off_by: Optional[str] = None
def to_dict(self):
return {
'id': self.id,
'namespace_id': self.namespace_id,
'tenant_id': self.tenant_id,
'name': self.name,
'beschreibung': self.beschreibung,
'status': self.status,
'owner_id': self.owner_id,
'created_at': self.created_at.isoformat(),
'locked_at': self.locked_at.isoformat() if self.locked_at else None,
'signed_off_at': self.signed_off_at.isoformat() if self.signed_off_at else None,
'signed_off_by': self.signed_off_by
}
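if __name__ == "__main__":
    # Illustrative smoke check of the dataclasses above (invented values);
    # this module has no relative imports, so it can be run directly.
    demo_policy = PolicySet(
        id="DEMO-2025", bundesland="niedersachsen", jahr=2025,
        fach=None, verfahren="halbjahreszeugnis",
    )
    print(demo_policy.is_certificate_policy())            # -> True
    print(demo_policy.to_dict()["zk_visibility_mode"])    # -> full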
@@ -1,200 +1,4 @@
"""
Compliance Extraction & Generation.
Functions for extracting checkpoints from legal text chunks,
generating controls, and creating remediation measures.
"""
import re
import hashlib
import logging
from typing import Dict, List, Optional
from compliance_models import Checkpoint, Control, Measure
logger = logging.getLogger(__name__)
def extract_checkpoints_from_chunk(chunk_text: str, payload: Dict) -> List[Checkpoint]:
"""
Extract checkpoints/requirements from a chunk of text.
Uses pattern matching to find requirement-like statements.
"""
checkpoints = []
regulation_code = payload.get("regulation_code", "UNKNOWN")
regulation_name = payload.get("regulation_name", "Unknown")
source_url = payload.get("source_url", "")
chunk_id = hashlib.md5(chunk_text[:100].encode()).hexdigest()[:8]
# Patterns for different requirement types
patterns = [
# BSI-TR patterns
(r'([OT]\.[A-Za-z_]+\d*)[:\s]+(.+?)(?=\n[OT]\.|$)', 'bsi_requirement'),
# Article patterns (GDPR, AI Act, etc.)
(r'(?:Artikel|Art\.?)\s+(\d+)(?:\s+Abs(?:atz)?\.?\s*(\d+))?\s*[-\u2013:]\s*(.+?)(?=\n|$)', 'article'),
# Numbered requirements
(r'\((\d+)\)\s+(.+?)(?=\n\(\d+\)|$)', 'numbered'),
# "Der Verantwortliche muss" patterns
(r'(?:Der Verantwortliche|Die Aufsichtsbeh\u00f6rde|Der Auftragsverarbeiter)\s+(muss|hat|soll)\s+(.+?)(?=\.\s|$)', 'obligation'),
# "Es ist erforderlich" patterns
(r'(?:Es ist erforderlich|Es muss gew\u00e4hrleistet|Es sind geeignete)\s+(.+?)(?=\.\s|$)', 'requirement'),
]
for pattern, pattern_type in patterns:
matches = re.finditer(pattern, chunk_text, re.MULTILINE | re.DOTALL)
for match in matches:
if pattern_type == 'bsi_requirement':
req_id = match.group(1)
description = match.group(2).strip()
title = req_id
elif pattern_type == 'article':
article_num = match.group(1)
paragraph = match.group(2) or ""
title_text = match.group(3).strip()
req_id = f"{regulation_code}-Art{article_num}"
if paragraph:
req_id += f"-{paragraph}"
title = f"Art. {article_num}" + (f" Abs. {paragraph}" if paragraph else "")
description = title_text
elif pattern_type == 'numbered':
num = match.group(1)
description = match.group(2).strip()
req_id = f"{regulation_code}-{num}"
title = f"Anforderung {num}"
else:
# Generic requirement
description = match.group(0).strip()
req_id = f"{regulation_code}-{chunk_id}-{len(checkpoints)}"
title = description[:50] + "..." if len(description) > 50 else description
# Skip very short matches
if len(description) < 20:
continue
checkpoint = Checkpoint(
id=req_id,
regulation_code=regulation_code,
regulation_name=regulation_name,
article=title if 'Art' in title else None,
title=title,
description=description[:500],
original_text=description,
chunk_id=chunk_id,
source_url=source_url
)
checkpoints.append(checkpoint)
return checkpoints
def generate_control_for_checkpoints(
checkpoints: List[Checkpoint],
domain_counts: Dict[str, int],
) -> Optional[Control]:
"""
Generate a control that covers the given checkpoints.
This is a simplified version - in production this would use the AI assistant.
"""
if not checkpoints:
return None
# Group by regulation
regulation = checkpoints[0].regulation_code
# Determine domain based on content
all_text = " ".join([cp.description for cp in checkpoints]).lower()
domain = "gov" # Default
if any(kw in all_text for kw in ["verschl\u00fcssel", "krypto", "encrypt", "hash"]):
domain = "crypto"
elif any(kw in all_text for kw in ["zugang", "access", "authentif", "login", "benutzer"]):
domain = "iam"
elif any(kw in all_text for kw in ["datenschutz", "personenbezogen", "privacy", "einwilligung"]):
domain = "priv"
elif any(kw in all_text for kw in ["entwicklung", "test", "code", "software"]):
domain = "sdlc"
elif any(kw in all_text for kw in ["\u00fcberwach", "monitor", "log", "audit"]):
domain = "aud"
elif any(kw in all_text for kw in ["ki", "k\u00fcnstlich", "ai", "machine learning", "model"]):
domain = "ai"
elif any(kw in all_text for kw in ["betrieb", "operation", "verf\u00fcgbar", "backup"]):
domain = "ops"
elif any(kw in all_text for kw in ["cyber", "resilience", "sbom", "vulnerab"]):
domain = "cra"
# Generate control ID
domain_count = domain_counts.get(domain, 0) + 1
control_id = f"{domain.upper()}-{domain_count:03d}"
# Create title from first checkpoint
title = checkpoints[0].title
if len(title) > 100:
title = title[:97] + "..."
# Create description
description = f"Control f\u00fcr {regulation}: " + checkpoints[0].description[:200]
# Pass criteria
pass_criteria = f"Alle {len(checkpoints)} zugeh\u00f6rigen Anforderungen sind erf\u00fcllt und dokumentiert."
# Implementation guidance
guidance = f"Implementiere Ma\u00dfnahmen zur Erf\u00fcllung der Anforderungen aus {regulation}. "
guidance += f"Dokumentiere die Umsetzung und f\u00fchre regelm\u00e4\u00dfige Reviews durch."
# Determine if automated
is_automated = any(kw in all_text for kw in ["automat", "tool", "scan", "test"])
control = Control(
id=control_id,
domain=domain,
title=title,
description=description,
checkpoints=[cp.id for cp in checkpoints],
pass_criteria=pass_criteria,
implementation_guidance=guidance,
is_automated=is_automated,
automation_tool="CI/CD Pipeline" if is_automated else None,
priority="high" if "muss" in all_text or "erforderlich" in all_text else "medium"
)
return control
def generate_measure_for_control(control: Control) -> Measure:
"""Generate a remediation measure for a control."""
measure_id = f"M-{control.id}"
# Determine deadline based on priority
deadline_days = {
"critical": 30,
"high": 60,
"medium": 90,
"low": 180
}.get(control.priority, 90)
# Determine responsible team
responsible = {
"priv": "Datenschutzbeauftragter",
"iam": "IT-Security Team",
"sdlc": "Entwicklungsteam",
"crypto": "IT-Security Team",
"ops": "Operations Team",
"aud": "Compliance Team",
"ai": "AI/ML Team",
"cra": "IT-Security Team",
"gov": "Management"
}.get(control.domain, "Compliance Team")
measure = Measure(
id=measure_id,
control_id=control.id,
title=f"Umsetzung: {control.title[:50]}",
description=f"Implementierung und Dokumentation von {control.id}: {control.description[:100]}",
responsible=responsible,
deadline_days=deadline_days,
status="pending"
)
return measure
# Backward-compat shim -- module moved to compliance/extraction.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("compliance.extraction")
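A minimal usage sketch of the relocated extraction helpers, assuming the new compliance package is importable; the chunk text and payload values below are hypothetical samples, not fixtures from the repository.

from compliance.extraction import (
    extract_checkpoints_from_chunk,
    generate_control_for_checkpoints,
    generate_measure_for_control,
)

# Hypothetical chunk and payload, shaped like the legal-corpus Qdrant payloads above.
sample_chunk = "(1) Der Verantwortliche muss geeignete technische Massnahmen treffen."
payload = {
    "regulation_code": "DSGVO",
    "regulation_name": "Datenschutz-Grundverordnung",
    "source_url": "https://example.invalid/dsgvo",
}

checkpoints = extract_checkpoints_from_chunk(sample_chunk, payload)
control = generate_control_for_checkpoints(checkpoints, domain_counts={})
if control is not None:
    measure = generate_measure_for_control(control)
    print(control.id, control.priority, measure.responsible, measure.deadline_days)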
+4 -49
View File
@@ -1,49 +1,4 @@
"""
Compliance Pipeline Data Models.
Dataclasses for checkpoints, controls, and measures.
"""
from typing import Optional, List
from dataclasses import dataclass
@dataclass
class Checkpoint:
"""A requirement/checkpoint extracted from legal text."""
id: str
regulation_code: str
regulation_name: str
article: Optional[str]
title: str
description: str
original_text: str
chunk_id: str
source_url: str
@dataclass
class Control:
"""A control derived from checkpoints."""
id: str
domain: str
title: str
description: str
checkpoints: List[str] # List of checkpoint IDs
pass_criteria: str
implementation_guidance: str
is_automated: bool
automation_tool: Optional[str]
priority: str
@dataclass
class Measure:
"""A remediation measure for a control."""
id: str
control_id: str
title: str
description: str
responsible: str
deadline_days: int
status: str
# Backward-compat shim -- module moved to compliance/models.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("compliance.models")
+4 -441
View File
@@ -1,441 +1,4 @@
"""
Compliance Pipeline Execution.
Pipeline phases (ingestion, extraction, control generation, measures)
and orchestration logic.
"""
import asyncio
import json
import logging
import os
import sys
import time
from datetime import datetime
from typing import Dict, List, Any
from dataclasses import asdict
from compliance_models import Checkpoint, Control, Measure
from compliance_extraction import (
extract_checkpoints_from_chunk,
generate_control_for_checkpoints,
generate_measure_for_control,
)
logger = logging.getLogger(__name__)
# Import checkpoint manager
try:
from pipeline_checkpoints import CheckpointManager, EXPECTED_VALUES, ValidationStatus
except ImportError:
logger.warning("Checkpoint manager not available, running without checkpoints")
CheckpointManager = None
EXPECTED_VALUES = {}
ValidationStatus = None
# Set environment variables for Docker network
if not os.getenv("QDRANT_URL") and not os.getenv("QDRANT_HOST"):
os.environ["QDRANT_HOST"] = "qdrant"
os.environ.setdefault("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
# Try to import from klausur-service
try:
from legal_corpus_ingestion import LegalCorpusIngestion, REGULATIONS, LEGAL_CORPUS_COLLECTION
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
except ImportError:
logger.error("Could not import required modules. Make sure you're in the klausur-service container.")
sys.exit(1)
class CompliancePipeline:
"""Handles the full compliance pipeline."""
def __init__(self):
# Support both QDRANT_URL and QDRANT_HOST/PORT
qdrant_url = os.getenv("QDRANT_URL", "")
if qdrant_url:
from urllib.parse import urlparse
parsed = urlparse(qdrant_url)
qdrant_host = parsed.hostname or "qdrant"
qdrant_port = parsed.port or 6333
else:
qdrant_host = os.getenv("QDRANT_HOST", "qdrant")
qdrant_port = 6333
self.qdrant = QdrantClient(host=qdrant_host, port=qdrant_port)
self.checkpoints: List[Checkpoint] = []
self.controls: List[Control] = []
self.measures: List[Measure] = []
self.stats = {
"chunks_processed": 0,
"checkpoints_extracted": 0,
"controls_created": 0,
"measures_defined": 0,
"by_regulation": {},
"by_domain": {},
}
# Initialize checkpoint manager
self.checkpoint_mgr = CheckpointManager() if CheckpointManager else None
async def run_ingestion_phase(self, force_reindex: bool = False) -> int:
"""Phase 1: Ingest documents (incremental - only missing ones)."""
logger.info("\n" + "=" * 60)
logger.info("PHASE 1: DOCUMENT INGESTION (INCREMENTAL)")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.start_checkpoint("ingestion", "Document Ingestion")
ingestion = LegalCorpusIngestion()
try:
# Check existing chunks per regulation
existing_chunks = {}
try:
for regulation in REGULATIONS:
count_result = self.qdrant.count(
collection_name=LEGAL_CORPUS_COLLECTION,
count_filter=Filter(
must=[FieldCondition(key="regulation_code", match=MatchValue(value=regulation.code))]
)
)
existing_chunks[regulation.code] = count_result.count
logger.info(f" {regulation.code}: {count_result.count} existing chunks")
except Exception as e:
logger.warning(f"Could not check existing chunks: {e}")
# Determine which regulations need ingestion
regulations_to_ingest = []
for regulation in REGULATIONS:
existing = existing_chunks.get(regulation.code, 0)
if force_reindex or existing == 0:
regulations_to_ingest.append(regulation)
logger.info(f" -> Will ingest: {regulation.code} (existing: {existing}, force: {force_reindex})")
else:
logger.info(f" -> Skipping: {regulation.code} (already has {existing} chunks)")
self.stats["by_regulation"][regulation.code] = existing
if not regulations_to_ingest:
logger.info("All regulations already indexed. Skipping ingestion phase.")
total_chunks = sum(existing_chunks.values())
self.stats["chunks_processed"] = total_chunks
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
self.checkpoint_mgr.add_metric("skipped", True)
self.checkpoint_mgr.complete_checkpoint(success=True)
return total_chunks
# Ingest only missing regulations
total_chunks = sum(existing_chunks.values())
for i, regulation in enumerate(regulations_to_ingest, 1):
logger.info(f"[{i}/{len(regulations_to_ingest)}] Ingesting {regulation.code}...")
try:
count = await ingestion.ingest_regulation(regulation)
total_chunks += count
self.stats["by_regulation"][regulation.code] = count
logger.info(f" -> {count} chunks")
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric(f"chunks_{regulation.code}", count)
except Exception as e:
logger.error(f" -> FAILED: {e}")
self.stats["by_regulation"][regulation.code] = 0
self.stats["chunks_processed"] = total_chunks
logger.info(f"\nTotal chunks in collection: {total_chunks}")
# Validate ingestion results
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_chunks", total_chunks)
self.checkpoint_mgr.add_metric("regulations_count", len(REGULATIONS))
expected = EXPECTED_VALUES.get("ingestion", {})
self.checkpoint_mgr.validate(
"total_chunks",
expected=expected.get("total_chunks", 8000),
actual=total_chunks,
min_value=expected.get("min_chunks", 7000)
)
reg_expected = expected.get("regulations", {})
for reg_code, reg_exp in reg_expected.items():
actual = self.stats["by_regulation"].get(reg_code, 0)
self.checkpoint_mgr.validate(
f"chunks_{reg_code}",
expected=reg_exp.get("expected", 0),
actual=actual,
min_value=reg_exp.get("min", 0)
)
self.checkpoint_mgr.complete_checkpoint(success=True)
return total_chunks
except Exception as e:
if self.checkpoint_mgr:
self.checkpoint_mgr.fail_checkpoint(str(e))
raise
finally:
await ingestion.close()
async def run_extraction_phase(self) -> int:
"""Phase 2: Extract checkpoints from chunks."""
logger.info("\n" + "=" * 60)
logger.info("PHASE 2: CHECKPOINT EXTRACTION")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.start_checkpoint("extraction", "Checkpoint Extraction")
try:
offset = None
total_checkpoints = 0
while True:
result = self.qdrant.scroll(
collection_name=LEGAL_CORPUS_COLLECTION,
limit=100,
offset=offset,
with_payload=True,
with_vectors=False
)
points, next_offset = result
if not points:
break
for point in points:
payload = point.payload
text = payload.get("text", "")
cps = extract_checkpoints_from_chunk(text, payload)
self.checkpoints.extend(cps)
total_checkpoints += len(cps)
logger.info(f"Processed {len(points)} chunks, extracted {total_checkpoints} checkpoints so far...")
if next_offset is None:
break
offset = next_offset
self.stats["checkpoints_extracted"] = len(self.checkpoints)
logger.info(f"\nTotal checkpoints extracted: {len(self.checkpoints)}")
by_reg = {}
for cp in self.checkpoints:
by_reg[cp.regulation_code] = by_reg.get(cp.regulation_code, 0) + 1
for reg, count in sorted(by_reg.items()):
logger.info(f" {reg}: {count} checkpoints")
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_checkpoints", len(self.checkpoints))
self.checkpoint_mgr.add_metric("checkpoints_by_regulation", by_reg)
expected = EXPECTED_VALUES.get("extraction", {})
self.checkpoint_mgr.validate(
"total_checkpoints",
expected=expected.get("total_checkpoints", 3500),
actual=len(self.checkpoints),
min_value=expected.get("min_checkpoints", 3000)
)
self.checkpoint_mgr.complete_checkpoint(success=True)
return len(self.checkpoints)
except Exception as e:
if self.checkpoint_mgr:
self.checkpoint_mgr.fail_checkpoint(str(e))
raise
async def run_control_generation_phase(self) -> int:
"""Phase 3: Generate controls from checkpoints."""
logger.info("\n" + "=" * 60)
logger.info("PHASE 3: CONTROL GENERATION")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.start_checkpoint("controls", "Control Generation")
try:
# Group checkpoints by regulation
by_regulation: Dict[str, List[Checkpoint]] = {}
for cp in self.checkpoints:
reg = cp.regulation_code
if reg not in by_regulation:
by_regulation[reg] = []
by_regulation[reg].append(cp)
# Generate controls per regulation (group every 3-5 checkpoints)
for regulation, checkpoints in by_regulation.items():
logger.info(f"Generating controls for {regulation} ({len(checkpoints)} checkpoints)...")
batch_size = 4
for i in range(0, len(checkpoints), batch_size):
batch = checkpoints[i:i + batch_size]
control = generate_control_for_checkpoints(batch, self.stats.get("by_domain", {}))
if control:
self.controls.append(control)
self.stats["by_domain"][control.domain] = self.stats["by_domain"].get(control.domain, 0) + 1
self.stats["controls_created"] = len(self.controls)
logger.info(f"\nTotal controls created: {len(self.controls)}")
for domain, count in sorted(self.stats["by_domain"].items()):
logger.info(f" {domain}: {count} controls")
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_controls", len(self.controls))
self.checkpoint_mgr.add_metric("controls_by_domain", dict(self.stats["by_domain"]))
expected = EXPECTED_VALUES.get("controls", {})
self.checkpoint_mgr.validate(
"total_controls",
expected=expected.get("total_controls", 900),
actual=len(self.controls),
min_value=expected.get("min_controls", 800)
)
self.checkpoint_mgr.complete_checkpoint(success=True)
return len(self.controls)
except Exception as e:
if self.checkpoint_mgr:
self.checkpoint_mgr.fail_checkpoint(str(e))
raise
async def run_measure_generation_phase(self) -> int:
"""Phase 4: Generate measures for controls."""
logger.info("\n" + "=" * 60)
logger.info("PHASE 4: MEASURE GENERATION")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.start_checkpoint("measures", "Measure Generation")
try:
for control in self.controls:
measure = generate_measure_for_control(control)
self.measures.append(measure)
self.stats["measures_defined"] = len(self.measures)
logger.info(f"\nTotal measures defined: {len(self.measures)}")
if self.checkpoint_mgr:
self.checkpoint_mgr.add_metric("total_measures", len(self.measures))
expected = EXPECTED_VALUES.get("measures", {})
self.checkpoint_mgr.validate(
"total_measures",
expected=expected.get("total_measures", 900),
actual=len(self.measures),
min_value=expected.get("min_measures", 800)
)
self.checkpoint_mgr.complete_checkpoint(success=True)
return len(self.measures)
except Exception as e:
if self.checkpoint_mgr:
self.checkpoint_mgr.fail_checkpoint(str(e))
raise
def save_results(self, output_dir: str = "/tmp/compliance_output"):
"""Save results to JSON files."""
logger.info("\n" + "=" * 60)
logger.info("SAVING RESULTS")
logger.info("=" * 60)
os.makedirs(output_dir, exist_ok=True)
checkpoints_file = os.path.join(output_dir, "checkpoints.json")
with open(checkpoints_file, "w") as f:
json.dump([asdict(cp) for cp in self.checkpoints], f, indent=2, ensure_ascii=False)
logger.info(f"Saved {len(self.checkpoints)} checkpoints to {checkpoints_file}")
controls_file = os.path.join(output_dir, "controls.json")
with open(controls_file, "w") as f:
json.dump([asdict(c) for c in self.controls], f, indent=2, ensure_ascii=False)
logger.info(f"Saved {len(self.controls)} controls to {controls_file}")
measures_file = os.path.join(output_dir, "measures.json")
with open(measures_file, "w") as f:
json.dump([asdict(m) for m in self.measures], f, indent=2, ensure_ascii=False)
logger.info(f"Saved {len(self.measures)} measures to {measures_file}")
stats_file = os.path.join(output_dir, "statistics.json")
self.stats["generated_at"] = datetime.now().isoformat()
with open(stats_file, "w") as f:
json.dump(self.stats, f, indent=2, ensure_ascii=False)
logger.info(f"Saved statistics to {stats_file}")
async def run_full_pipeline(self, force_reindex: bool = False, skip_ingestion: bool = False):
"""Run the complete pipeline.
Args:
force_reindex: If True, re-ingest all documents even if they exist
skip_ingestion: If True, skip ingestion phase entirely (use existing chunks)
"""
start_time = time.time()
logger.info("=" * 60)
logger.info("FULL COMPLIANCE PIPELINE (INCREMENTAL)")
logger.info(f"Started at: {datetime.now().isoformat()}")
logger.info(f"Force reindex: {force_reindex}")
logger.info(f"Skip ingestion: {skip_ingestion}")
if self.checkpoint_mgr:
logger.info(f"Pipeline ID: {self.checkpoint_mgr.pipeline_id}")
logger.info("=" * 60)
try:
if skip_ingestion:
logger.info("Skipping ingestion phase as requested...")
try:
collection_info = self.qdrant.get_collection(LEGAL_CORPUS_COLLECTION)
self.stats["chunks_processed"] = collection_info.points_count
except Exception:
self.stats["chunks_processed"] = 0
else:
await self.run_ingestion_phase(force_reindex=force_reindex)
await self.run_extraction_phase()
await self.run_control_generation_phase()
await self.run_measure_generation_phase()
self.save_results()
elapsed = time.time() - start_time
logger.info("\n" + "=" * 60)
logger.info("PIPELINE COMPLETE")
logger.info("=" * 60)
logger.info(f"Duration: {elapsed:.1f} seconds")
logger.info(f"Chunks processed: {self.stats['chunks_processed']}")
logger.info(f"Checkpoints extracted: {self.stats['checkpoints_extracted']}")
logger.info(f"Controls created: {self.stats['controls_created']}")
logger.info(f"Measures defined: {self.stats['measures_defined']}")
logger.info(f"\nResults saved to: /tmp/compliance_output/")
logger.info("Checkpoint status: /tmp/pipeline_checkpoints.json")
logger.info("=" * 60)
if self.checkpoint_mgr:
self.checkpoint_mgr.complete_pipeline({
"duration_seconds": elapsed,
"chunks_processed": self.stats['chunks_processed'],
"checkpoints_extracted": self.stats['checkpoints_extracted'],
"controls_created": self.stats['controls_created'],
"measures_defined": self.stats['measures_defined'],
"by_regulation": self.stats['by_regulation'],
"by_domain": self.stats['by_domain'],
})
except Exception as e:
logger.error(f"Pipeline failed: {e}")
if self.checkpoint_mgr:
self.checkpoint_mgr.state.status = "failed"
self.checkpoint_mgr._save()
raise
# Backward-compat shim -- module moved to compliance/pipeline.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("compliance.pipeline")
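A hedged sketch of driving the relocated pipeline class directly from the new package path; it assumes Qdrant and the embedding service are reachable under the environment defaults configured above.

import asyncio

from compliance.pipeline import CompliancePipeline

async def run_incremental() -> None:
    # Reuse chunks that are already indexed; set force_reindex=True for a full re-ingest.
    pipeline = CompliancePipeline()
    await pipeline.run_full_pipeline(force_reindex=False, skip_ingestion=True)

if __name__ == "__main__":
    asyncio.run(run_incremental())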
+4 -420
View File
@@ -1,420 +1,4 @@
"""
BYOEH Processing Pipeline
Handles chunking, embedding generation, and encryption for Erwartungshorizonte.
Supports multiple embedding backends:
- local: sentence-transformers (default, no API key needed)
- openai: OpenAI text-embedding-3-small (requires OPENAI_API_KEY)
"""
import os
import io
import base64
import hashlib
from typing import List, Tuple, Optional
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives import hashes
import httpx
# Embedding Configuration
# Backend: "local" (sentence-transformers) or "openai"
EMBEDDING_BACKEND = os.getenv("EMBEDDING_BACKEND", "local")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
# Local embedding model (all-MiniLM-L6-v2: 384 dimensions, fast, good quality)
LOCAL_EMBEDDING_MODEL = os.getenv("LOCAL_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
# Vector dimensions per backend
VECTOR_DIMENSIONS = {
"local": 384, # all-MiniLM-L6-v2
"openai": 1536, # text-embedding-3-small
}
CHUNK_SIZE = int(os.getenv("BYOEH_CHUNK_SIZE", "1000"))
CHUNK_OVERLAP = int(os.getenv("BYOEH_CHUNK_OVERLAP", "200"))
# Lazy-loaded sentence-transformers model
_local_model = None
class ChunkingError(Exception):
"""Error during text chunking."""
pass
class EmbeddingError(Exception):
"""Error during embedding generation."""
pass
class EncryptionError(Exception):
"""Error during encryption/decryption."""
pass
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
"""
Split text into overlapping chunks.
Uses a simple recursive character splitter approach:
- Try to split on paragraph boundaries first
- Then sentences
- Then words
- Finally characters
Args:
text: Input text to chunk
chunk_size: Target chunk size in characters
overlap: Overlap between chunks
Returns:
List of text chunks
"""
if not text or len(text) <= chunk_size:
return [text] if text else []
chunks = []
separators = ["\n\n", "\n", ". ", " ", ""]
def split_recursive(text: str, sep_idx: int = 0) -> List[str]:
if len(text) <= chunk_size:
return [text]
if sep_idx >= len(separators):
# Last resort: hard split
return [text[i:i+chunk_size] for i in range(0, len(text), chunk_size - overlap)]
sep = separators[sep_idx]
if not sep:
# Empty separator = character split
parts = list(text)
else:
parts = text.split(sep)
result = []
current = ""
for part in parts:
test_chunk = current + sep + part if current else part
if len(test_chunk) <= chunk_size:
current = test_chunk
else:
if current:
result.append(current)
# If single part is too big, recursively split it
if len(part) > chunk_size:
result.extend(split_recursive(part, sep_idx + 1))
current = ""
else:
current = part
if current:
result.append(current)
return result
raw_chunks = split_recursive(text)
# Add overlap
final_chunks = []
for i, chunk in enumerate(raw_chunks):
if i > 0 and overlap > 0:
# Add overlap from previous chunk
prev_chunk = raw_chunks[i-1]
overlap_text = prev_chunk[-min(overlap, len(prev_chunk)):]
chunk = overlap_text + chunk
final_chunks.append(chunk.strip())
return [c for c in final_chunks if c]
def get_vector_size() -> int:
"""Get the vector dimension for the current embedding backend."""
return VECTOR_DIMENSIONS.get(EMBEDDING_BACKEND, 384)
def _get_local_model():
"""Lazy-load the sentence-transformers model."""
global _local_model
if _local_model is None:
try:
from sentence_transformers import SentenceTransformer
print(f"Loading local embedding model: {LOCAL_EMBEDDING_MODEL}")
_local_model = SentenceTransformer(LOCAL_EMBEDDING_MODEL)
print(f"Model loaded successfully (dim={_local_model.get_sentence_embedding_dimension()})")
except ImportError:
raise EmbeddingError(
"sentence-transformers not installed. "
"Install with: pip install sentence-transformers"
)
return _local_model
def _generate_local_embeddings(texts: List[str]) -> List[List[float]]:
"""Generate embeddings using local sentence-transformers model."""
if not texts:
return []
model = _get_local_model()
embeddings = model.encode(texts, show_progress_bar=len(texts) > 10)
return [emb.tolist() for emb in embeddings]
async def _generate_openai_embeddings(texts: List[str]) -> List[List[float]]:
"""Generate embeddings using OpenAI API."""
if not OPENAI_API_KEY:
raise EmbeddingError("OPENAI_API_KEY not configured")
try:
async with httpx.AsyncClient() as client:
response = await client.post(
"https://api.openai.com/v1/embeddings",
headers={
"Authorization": f"Bearer {OPENAI_API_KEY}",
"Content-Type": "application/json"
},
json={
"model": EMBEDDING_MODEL,
"input": texts
},
timeout=60.0
)
if response.status_code != 200:
raise EmbeddingError(f"OpenAI API error: {response.status_code} - {response.text}")
data = response.json()
embeddings = [item["embedding"] for item in data["data"]]
return embeddings
except httpx.TimeoutException:
raise EmbeddingError("OpenAI API timeout")
except Exception as e:
raise EmbeddingError(f"Failed to generate embeddings: {str(e)}")
async def generate_embeddings(texts: List[str]) -> List[List[float]]:
"""
Generate embeddings using configured backend.
Backends:
- local: sentence-transformers (default, no API key needed)
- openai: OpenAI text-embedding-3-small
Args:
texts: List of text chunks
Returns:
List of embedding vectors
Raises:
EmbeddingError: If embedding generation fails
"""
if not texts:
return []
if EMBEDDING_BACKEND == "local":
# Local model runs synchronously but is fast
return _generate_local_embeddings(texts)
elif EMBEDDING_BACKEND == "openai":
return await _generate_openai_embeddings(texts)
else:
raise EmbeddingError(f"Unknown embedding backend: {EMBEDDING_BACKEND}")
async def generate_single_embedding(text: str) -> List[float]:
"""Generate embedding for a single text."""
embeddings = await generate_embeddings([text])
return embeddings[0] if embeddings else []
def derive_key(passphrase: str, salt: bytes) -> bytes:
"""
Derive encryption key from passphrase using PBKDF2.
Args:
passphrase: User passphrase
salt: Random salt (16 bytes)
Returns:
32-byte AES key
"""
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
)
return kdf.derive(passphrase.encode())
def encrypt_text(text: str, passphrase: str, salt_hex: str) -> str:
"""
Encrypt text using AES-256-GCM.
Args:
text: Plaintext to encrypt
passphrase: User passphrase
salt_hex: Salt as hex string
Returns:
Base64-encoded ciphertext (IV + ciphertext)
"""
try:
salt = bytes.fromhex(salt_hex)
key = derive_key(passphrase, salt)
aesgcm = AESGCM(key)
iv = os.urandom(12)
ciphertext = aesgcm.encrypt(iv, text.encode(), None)
# Combine IV + ciphertext
combined = iv + ciphertext
return base64.b64encode(combined).decode()
except Exception as e:
raise EncryptionError(f"Encryption failed: {str(e)}")
def decrypt_text(encrypted_b64: str, passphrase: str, salt_hex: str) -> str:
"""
Decrypt text using AES-256-GCM.
Args:
encrypted_b64: Base64-encoded ciphertext (IV + ciphertext)
passphrase: User passphrase
salt_hex: Salt as hex string
Returns:
Decrypted plaintext
"""
try:
salt = bytes.fromhex(salt_hex)
key = derive_key(passphrase, salt)
combined = base64.b64decode(encrypted_b64)
iv = combined[:12]
ciphertext = combined[12:]
aesgcm = AESGCM(key)
plaintext = aesgcm.decrypt(iv, ciphertext, None)
return plaintext.decode()
except Exception as e:
raise EncryptionError(f"Decryption failed: {str(e)}")
def hash_key(passphrase: str, salt_hex: str) -> str:
"""
Create SHA-256 hash of derived key for verification.
Args:
passphrase: User passphrase
salt_hex: Salt as hex string
Returns:
Hex-encoded key hash
"""
salt = bytes.fromhex(salt_hex)
key = derive_key(passphrase, salt)
return hashlib.sha256(key).hexdigest()
def verify_key_hash(passphrase: str, salt_hex: str, expected_hash: str) -> bool:
"""
Verify passphrase matches stored key hash.
Args:
passphrase: User passphrase to verify
salt_hex: Salt as hex string
expected_hash: Expected key hash
Returns:
True if passphrase is correct
"""
computed_hash = hash_key(passphrase, salt_hex)
return computed_hash == expected_hash
def extract_text_from_pdf(pdf_content: bytes) -> str:
"""
Extract text from PDF file.
Args:
pdf_content: Raw PDF bytes
Returns:
Extracted text
"""
try:
import PyPDF2
pdf_file = io.BytesIO(pdf_content)
reader = PyPDF2.PdfReader(pdf_file)
text_parts = []
for page in reader.pages:
text = page.extract_text()
if text:
text_parts.append(text)
return "\n\n".join(text_parts)
except ImportError:
raise ChunkingError("PyPDF2 not installed")
except Exception as e:
raise ChunkingError(f"Failed to extract PDF text: {str(e)}")
async def process_eh_for_indexing(
eh_id: str,
tenant_id: str,
subject: str,
text_content: str,
passphrase: str,
salt_hex: str
) -> Tuple[int, List[dict]]:
"""
Full processing pipeline for Erwartungshorizont indexing.
1. Chunk the text
2. Generate embeddings
3. Encrypt chunks
4. Return prepared data for Qdrant
Args:
eh_id: Erwartungshorizont ID
tenant_id: Tenant ID
subject: Subject (deutsch, englisch, etc.)
text_content: Decrypted text content
passphrase: User passphrase for re-encryption
salt_hex: Salt for encryption
Returns:
Tuple of (chunk_count, chunks_data)
"""
# 1. Chunk the text
chunks = chunk_text(text_content)
if not chunks:
return 0, []
# 2. Generate embeddings
embeddings = await generate_embeddings(chunks)
# 3. Encrypt chunks for storage
encrypted_chunks = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
encrypted_content = encrypt_text(chunk, passphrase, salt_hex)
encrypted_chunks.append({
"chunk_index": i,
"embedding": embedding,
"encrypted_content": encrypted_content
})
return len(chunks), encrypted_chunks
# Backward-compat shim -- module moved to korrektur/eh_pipeline.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("korrektur.eh_pipeline")
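A hedged round-trip sketch against the moved module: chunk a long text, then encrypt and decrypt one chunk. The passphrase and salt are throwaway values.

import os

from korrektur.eh_pipeline import chunk_text, decrypt_text, encrypt_text

salt_hex = os.urandom(16).hex()  # 16-byte salt, hex-encoded as the functions expect
chunks = chunk_text("Beispieltext fuer die Korrektur. " * 100, chunk_size=500, overlap=50)
cipher = encrypt_text(chunks[0], "geheimes-passwort", salt_hex)
assert decrypt_text(cipher, "geheimes-passwort", salt_hex) == chunks[0]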
+4 -34
View File
@@ -1,34 +1,4 @@
"""
Erwartungshorizont Templates for Vorabitur Mode — barrel re-export.
The actual code lives in:
- eh_templates_types.py (AUFGABENTYPEN, EHKriterium, EHTemplate)
- eh_templates_analyse.py (Textanalyse, Gedicht, Prosa, Drama)
- eh_templates_eroerterung.py (Eroerterung textgebunden)
- eh_templates_registry.py (TEMPLATES, get_template, list_templates, etc.)
"""
# Types
from eh_templates_types import ( # noqa: F401
AUFGABENTYPEN,
EHKriterium,
EHTemplate,
)
# Template factories
from eh_templates_analyse import ( # noqa: F401
get_textanalyse_template,
get_gedichtanalyse_template,
get_prosaanalyse_template,
get_dramenanalyse_template,
)
from eh_templates_eroerterung import get_eroerterung_template # noqa: F401
# Registry
from eh_templates_registry import ( # noqa: F401
TEMPLATES,
initialize_templates,
get_template,
list_templates,
get_aufgabentypen,
)
# Backward-compat shim -- module moved to korrektur/eh_templates.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("korrektur.eh_templates")
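Assuming backend/ stays on sys.path, the legacy flat import and the new package import should resolve to the same module object; a hedged check:

import eh_templates  # legacy name, redirected by the shim above
from korrektur import eh_templates as pkg_templates

# The shim swaps sys.modules["eh_templates"] for korrektur.eh_templates,
# so both names should point at the same module object.
assert eh_templates is pkg_templates
print(eh_templates.get_template("gedichtanalyse").name)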
+4 -395
View File
@@ -1,395 +1,4 @@
"""
Erwartungshorizont Templates — Analyse templates.
Contains templates for:
- Textanalyse (pragmatische Texte)
- Gedichtanalyse / Lyrikinterpretation
- Prosaanalyse
- Dramenanalyse
"""
from eh_templates_types import EHTemplate, EHKriterium
def get_textanalyse_template() -> EHTemplate:
"""Template for pragmatic text analysis."""
return EHTemplate(
id="template_textanalyse_pragmatisch",
aufgabentyp="textanalyse_pragmatisch",
name="Textanalyse pragmatischer Texte",
beschreibung="Vorlage fuer die Analyse von Sachtexten, Reden, Kommentaren und Essays",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Erfassung und Wiedergabe des Textinhalts",
gewichtung=40,
erwartungen=[
"Korrekte Erfassung der Textaussage/These",
"Vollstaendige Wiedergabe der Argumentationsstruktur",
"Erkennen von Intention und Adressatenbezug",
"Einordnung in den historischen/gesellschaftlichen Kontext",
"Beruecksichtigung aller relevanten Textaspekte"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau und Gliederung der Analyse",
gewichtung=15,
erwartungen=[
"Sinnvolle Einleitung mit Basisinformationen",
"Logische Gliederung des Hauptteils",
"Stringente Gedankenfuehrung",
"Angemessener Schluss mit Fazit/Wertung",
"Absatzgliederung und Ueberlaenge"
]
),
EHKriterium(
id="analyse",
name="Analytische Qualitaet",
beschreibung="Tiefe und Qualitaet der Analyse",
gewichtung=15,
erwartungen=[
"Erkennen rhetorischer Mittel",
"Funktionale Deutung der Stilmittel",
"Analyse der Argumentationsweise",
"Beruecksichtigung von Wortwahl und Satzbau",
"Verknuepfung von Form und Inhalt"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung",
"Korrekte Gross- und Kleinschreibung",
"Korrekte Getrennt- und Zusammenschreibung",
"Korrekte Fremdwortschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Flexion",
"Korrekte Zeichensetzung",
"Korrekte Bezuege und Kongruenz"
]
)
],
einleitung_hinweise=[
"Nennung von Autor, Titel, Textsorte, Erscheinungsjahr",
"Benennung des Themas",
"Formulierung der Kernthese/Hauptaussage",
"Ggf. Einordnung in den Kontext"
],
hauptteil_hinweise=[
"Systematische Analyse der Argumentationsstruktur",
"Untersuchung der sprachlichen Gestaltung",
"Funktionale Deutung der Stilmittel",
"Beruecksichtigung von Adressatenbezug und Intention",
"Textbelege durch Zitate"
],
schluss_hinweise=[
"Zusammenfassung der Analyseergebnisse",
"Bewertung der Ueberzeugungskraft",
"Ggf. aktuelle Relevanz",
"Persoenliche Stellungnahme (wenn gefordert)"
],
sprachliche_aspekte=[
"Fachsprachliche Begriffe korrekt verwenden",
"Konjunktiv fuer indirekte Rede",
"Praesens als Tempus der Analyse",
"Sachlicher, analytischer Stil"
]
)
def get_gedichtanalyse_template() -> EHTemplate:
"""Template for poetry analysis."""
return EHTemplate(
id="template_gedichtanalyse",
aufgabentyp="gedichtanalyse",
name="Gedichtanalyse / Lyrikinterpretation",
beschreibung="Vorlage fuer die Analyse und Interpretation lyrischer Texte",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Erfassung und Deutung des Gedichtinhalts",
gewichtung=40,
erwartungen=[
"Korrekte Erfassung des lyrischen Ichs und der Sprechsituation",
"Vollstaendige inhaltliche Erschliessung aller Strophen",
"Erkennen der zentralen Motive und Themen",
"Epochenzuordnung und literaturgeschichtliche Einordnung",
"Deutung der Bildlichkeit und Symbolik"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau der Interpretation",
gewichtung=15,
erwartungen=[
"Einleitung mit Basisinformationen",
"Systematische strophenweise oder aspektorientierte Analyse",
"Verknuepfung von Form- und Inhaltsanalyse",
"Schluessige Gesamtdeutung im Schluss"
]
),
EHKriterium(
id="formanalyse",
name="Formale Analyse",
beschreibung="Analyse der lyrischen Gestaltungsmittel",
gewichtung=15,
erwartungen=[
"Bestimmung von Metrum und Reimschema",
"Analyse der Klanggestaltung",
"Erkennen von Enjambements und Zaesuren",
"Deutung der formalen Mittel",
"Verknuepfung von Form und Inhalt"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung",
"Korrekte Gross- und Kleinschreibung",
"Korrekte Getrennt- und Zusammenschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Flexion",
"Korrekte Zeichensetzung"
]
)
],
einleitung_hinweise=[
"Autor, Titel, Entstehungsjahr/Epoche",
"Thema/Motiv des Gedichts",
"Erste Deutungshypothese",
"Formale Grunddaten (Strophen, Verse)"
],
hauptteil_hinweise=[
"Inhaltliche Analyse (strophenweise oder aspektorientiert)",
"Formale Analyse (Metrum, Reim, Klang)",
"Sprachliche Analyse (Stilmittel, Bildlichkeit)",
"Funktionale Verknuepfung aller Ebenen",
"Textbelege durch Zitate mit Versangabe"
],
schluss_hinweise=[
"Zusammenfassung der Interpretationsergebnisse",
"Bestaetigung/Modifikation der Deutungshypothese",
"Einordnung in Epoche/Werk des Autors",
"Aktualitaetsbezug (wenn sinnvoll)"
],
sprachliche_aspekte=[
"Fachbegriffe der Lyrikanalyse verwenden",
"Zwischen lyrischem Ich und Autor unterscheiden",
"Praesens als Analysetempus",
"Deutende statt beschreibende Formulierungen"
]
)
def get_prosaanalyse_template() -> EHTemplate:
"""Template for prose/narrative text analysis."""
return EHTemplate(
id="template_prosaanalyse",
aufgabentyp="prosaanalyse",
name="Epische Textanalyse / Prosaanalyse",
beschreibung="Vorlage fuer die Analyse von Romanauszuegen, Kurzgeschichten und Novellen",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Erfassung und Deutung des Textinhalts",
gewichtung=40,
erwartungen=[
"Korrekte Erfassung der Handlung",
"Charakterisierung der Figuren",
"Erkennen der Erzaehlsituation",
"Deutung der Konflikte und Motive",
"Einordnung in den Gesamtzusammenhang"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau der Analyse",
gewichtung=15,
erwartungen=[
"Informative Einleitung",
"Systematische Analyse im Hauptteil",
"Verknuepfung der Analyseergebnisse",
"Schluessige Gesamtdeutung"
]
),
EHKriterium(
id="erzaehltechnik",
name="Erzaehltechnische Analyse",
beschreibung="Analyse narrativer Gestaltungsmittel",
gewichtung=15,
erwartungen=[
"Bestimmung der Erzaehlperspektive",
"Analyse von Zeitgestaltung",
"Raumgestaltung und Atmosphaere",
"Figurenrede und Bewusstseinsdarstellung",
"Funktionale Deutung"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung",
"Korrekte Gross- und Kleinschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Zeichensetzung"
]
)
],
einleitung_hinweise=[
"Autor, Titel, Textsorte, Erscheinungsjahr",
"Einordnung des Auszugs in den Gesamttext",
"Thema und Deutungshypothese"
],
hauptteil_hinweise=[
"Kurze Inhaltsangabe des Auszugs",
"Analyse der Handlungsstruktur",
"Figurenanalyse mit Textbelegen",
"Erzaehltechnische Analyse",
"Sprachliche Analyse",
"Verknuepfung aller Ebenen"
],
schluss_hinweise=[
"Zusammenfassung der Analyseergebnisse",
"Bestaetigung der Deutungshypothese",
"Bedeutung fuer Gesamtwerk",
"Ggf. Aktualitaetsbezug"
],
sprachliche_aspekte=[
"Fachbegriffe der Erzaehltextanalyse",
"Zwischen Erzaehler und Autor unterscheiden",
"Praesens als Analysetempus",
"Deutende Formulierungen"
]
)
def get_dramenanalyse_template() -> EHTemplate:
"""Template for drama analysis."""
return EHTemplate(
id="template_dramenanalyse",
aufgabentyp="dramenanalyse",
name="Dramenanalyse",
beschreibung="Vorlage fuer die Analyse dramatischer Texte und Szenen",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Erfassung und Deutung des Szeneninhalts",
gewichtung=40,
erwartungen=[
"Korrekte Erfassung der Handlung",
"Analyse der Figurenkonstellation",
"Erkennen des dramatischen Konflikts",
"Einordnung in den Handlungsverlauf",
"Deutung der Szene im Gesamtzusammenhang"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau der Analyse",
gewichtung=15,
erwartungen=[
"Einleitung mit Kontextualisierung",
"Systematische Szenenanalyse",
"Verknuepfung der Analyseergebnisse",
"Schluessige Deutung"
]
),
EHKriterium(
id="dramentechnik",
name="Dramentechnische Analyse",
beschreibung="Analyse dramatischer Gestaltungsmittel",
gewichtung=15,
erwartungen=[
"Analyse der Dialoggestaltung",
"Regieanweisungen und Buehnenraum",
"Dramatische Spannung",
"Monolog/Dialog-Formen",
"Funktionale Deutung"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Zeichensetzung"
]
)
],
einleitung_hinweise=[
"Autor, Titel, Urauffuehrungsjahr, Dramenform",
"Einordnung der Szene in den Handlungsverlauf",
"Thema und Deutungshypothese"
],
hauptteil_hinweise=[
"Situierung der Szene",
"Analyse des Dialogverlaufs",
"Figurenanalyse im Dialog",
"Sprachliche Analyse",
"Dramentechnische Mittel",
"Bedeutung fuer den Konflikt"
],
schluss_hinweise=[
"Zusammenfassung der Analyseergebnisse",
"Funktion der Szene im Drama",
"Bedeutung fuer die Gesamtdeutung"
],
sprachliche_aspekte=[
"Fachbegriffe der Dramenanalyse",
"Praesens als Analysetempus",
"Korrekte Zitierweise mit Akt/Szene/Zeile"
]
)
# Backward-compat shim -- module moved to korrektur/eh_templates_analyse.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("korrektur.eh_templates_analyse")
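A hedged sanity check on the factories above: each template's criterion weights are meant to sum to 100 percent.

from korrektur.eh_templates_analyse import (
    get_gedichtanalyse_template,
    get_textanalyse_template,
)

for factory in (get_textanalyse_template, get_gedichtanalyse_template):
    template = factory()
    total = sum(k.gewichtung for k in template.kriterien)
    print(template.aufgabentyp, total)  # both should print 100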
@@ -1,101 +1,4 @@
"""
Erwartungshorizont Templates — Eroerterung template.
"""
from eh_templates_types import EHTemplate, EHKriterium
def get_eroerterung_template() -> EHTemplate:
"""Template for textgebundene Eroerterung."""
return EHTemplate(
id="template_eroerterung_textgebunden",
aufgabentyp="eroerterung_textgebunden",
name="Textgebundene Eroerterung",
beschreibung="Vorlage fuer die textgebundene Eroerterung auf Basis eines Sachtextes",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Qualitaet der Argumentation",
gewichtung=40,
erwartungen=[
"Korrekte Wiedergabe der Textposition",
"Differenzierte eigene Argumentation",
"Vielfaeltige und ueberzeugende Argumente",
"Beruecksichtigung von Pro und Contra",
"Sinnvolle Beispiele und Belege",
"Eigenstaendige Schlussfolgerung"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau der Eroerterung",
gewichtung=15,
erwartungen=[
"Problemorientierte Einleitung",
"Klare Gliederung der Argumentation",
"Logische Argumentationsfolge",
"Sinnvolle Ueberlaetze",
"Begruendetes Fazit"
]
),
EHKriterium(
id="textbezug",
name="Textbezug",
beschreibung="Verknuepfung mit dem Ausgangstext",
gewichtung=15,
erwartungen=[
"Angemessene Textwiedergabe",
"Kritische Auseinandersetzung mit Textposition",
"Korrekte Zitierweise",
"Verknuepfung eigener Argumente mit Text"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung",
"Korrekte Gross- und Kleinschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Zeichensetzung",
"Variationsreicher Ausdruck"
]
)
],
einleitung_hinweise=[
"Hinfuehrung zum Thema",
"Nennung des Ausgangstextes",
"Formulierung der Leitfrage/These",
"Ueberleitung zum Hauptteil"
],
hauptteil_hinweise=[
"Kurze Wiedergabe der Textposition",
"Systematische Argumentation (dialektisch oder linear)",
"Jedes Argument: These - Begruendung - Beispiel",
"Gewichtung der Argumente",
"Verknuepfung mit Textposition"
],
schluss_hinweise=[
"Zusammenfassung der wichtigsten Argumente",
"Eigene begruendete Stellungnahme",
"Ggf. Ausblick oder Appell"
],
sprachliche_aspekte=[
"Argumentative Konnektoren verwenden",
"Sachlicher, ueberzeugender Stil",
"Eigene Meinung kennzeichnen",
"Konjunktiv fuer Textpositionen"
]
)
# Backward-compat shim -- module moved to korrektur/eh_templates_eroerterung.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("korrektur.eh_templates_eroerterung")
@@ -1,60 +1,4 @@
"""
Erwartungshorizont Templates — registry for template lookup.
"""
from typing import Dict, List, Optional
from eh_templates_types import EHTemplate, AUFGABENTYPEN
from eh_templates_analyse import (
get_textanalyse_template,
get_gedichtanalyse_template,
get_prosaanalyse_template,
get_dramenanalyse_template,
)
from eh_templates_eroerterung import get_eroerterung_template
TEMPLATES: Dict[str, EHTemplate] = {}
def initialize_templates():
"""Initialize all pre-defined templates."""
global TEMPLATES
TEMPLATES = {
"textanalyse_pragmatisch": get_textanalyse_template(),
"gedichtanalyse": get_gedichtanalyse_template(),
"eroerterung_textgebunden": get_eroerterung_template(),
"prosaanalyse": get_prosaanalyse_template(),
"dramenanalyse": get_dramenanalyse_template(),
}
def get_template(aufgabentyp: str) -> Optional[EHTemplate]:
"""Get a template by Aufgabentyp."""
if not TEMPLATES:
initialize_templates()
return TEMPLATES.get(aufgabentyp)
def list_templates() -> List[Dict]:
"""List all available templates."""
if not TEMPLATES:
initialize_templates()
return [
{
"aufgabentyp": typ,
"name": AUFGABENTYPEN.get(typ, {}).get("name", typ),
"description": AUFGABENTYPEN.get(typ, {}).get("description", ""),
"category": AUFGABENTYPEN.get(typ, {}).get("category", "other"),
}
for typ in TEMPLATES.keys()
]
def get_aufgabentypen() -> Dict:
"""Get all Aufgabentypen definitions."""
return AUFGABENTYPEN
# Initialize on import
initialize_templates()
# Backward-compat shim -- module moved to korrektur/eh_templates_registry.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("korrektur.eh_templates_registry")
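A hedged lookup sketch against the moved registry; the same calls keep working through the legacy import thanks to the shim above.

from korrektur.eh_templates_registry import get_template, list_templates

template = get_template("eroerterung_textgebunden")
if template is not None:
    print(template.name, "->", [k.id for k in template.kriterien])

for entry in list_templates():
    print(f'{entry["aufgabentyp"]}: {entry["name"]} ({entry["category"]})')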
+4 -100
View File
@@ -1,100 +1,4 @@
"""
Erwartungshorizont Templates — types and Aufgabentypen registry.
"""
from typing import Dict, List, Optional
from dataclasses import dataclass, field, asdict
from datetime import datetime
AUFGABENTYPEN = {
"textanalyse_pragmatisch": {
"name": "Textanalyse (pragmatische Texte)",
"description": "Analyse von Sachtexten, Reden, Kommentaren, Essays",
"category": "analyse"
},
"sachtextanalyse": {
"name": "Sachtextanalyse",
"description": "Analyse von informativen und appellativen Sachtexten",
"category": "analyse"
},
"gedichtanalyse": {
"name": "Gedichtanalyse / Lyrikinterpretation",
"description": "Analyse und Interpretation lyrischer Texte",
"category": "interpretation"
},
"dramenanalyse": {
"name": "Dramenanalyse",
"description": "Analyse dramatischer Texte und Szenen",
"category": "interpretation"
},
"prosaanalyse": {
"name": "Epische Textanalyse / Prosaanalyse",
"description": "Analyse von Romanauszuegen, Kurzgeschichten, Novellen",
"category": "interpretation"
},
"eroerterung_textgebunden": {
"name": "Textgebundene Eroerterung",
"description": "Eroerterung auf Basis eines Sachtextes",
"category": "argumentation"
},
"eroerterung_frei": {
"name": "Freie Eroerterung",
"description": "Freie Eroerterung zu einem Thema",
"category": "argumentation"
},
"eroerterung_literarisch": {
"name": "Literarische Eroerterung",
"description": "Eroerterung zu literarischen Fragestellungen",
"category": "argumentation"
},
"materialgestuetzt": {
"name": "Materialgestuetztes Schreiben",
"description": "Verfassen eines Textes auf Materialbasis",
"category": "produktion"
}
}
@dataclass
class EHKriterium:
"""Single criterion in an Erwartungshorizont."""
id: str
name: str
beschreibung: str
gewichtung: int # Percentage weight (0-100)
erwartungen: List[str] # Expected points/elements
max_punkte: int = 100
def to_dict(self):
return asdict(self)
@dataclass
class EHTemplate:
"""Complete Erwartungshorizont template."""
id: str
aufgabentyp: str
name: str
beschreibung: str
kriterien: List[EHKriterium]
einleitung_hinweise: List[str]
hauptteil_hinweise: List[str]
schluss_hinweise: List[str]
sprachliche_aspekte: List[str]
created_at: datetime = field(default_factory=lambda: datetime.now())
def to_dict(self):
d = {
'id': self.id,
'aufgabentyp': self.aufgabentyp,
'name': self.name,
'beschreibung': self.beschreibung,
'kriterien': [k.to_dict() for k in self.kriterien],
'einleitung_hinweise': self.einleitung_hinweise,
'hauptteil_hinweise': self.hauptteil_hinweise,
'schluss_hinweise': self.schluss_hinweise,
'sprachliche_aspekte': self.sprachliche_aspekte,
'created_at': self.created_at.isoformat()
}
return d
# Backward-compat shim -- module moved to korrektur/eh_templates_types.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("korrektur.eh_templates_types")
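For illustration, a minimal hand-built template using the dataclasses above; all values are hypothetical, and to_dict() returns a plain-dict representation including the serialized criteria.

from korrektur.eh_templates_types import EHKriterium, EHTemplate

demo = EHTemplate(
    id="template_demo",  # hypothetical id
    aufgabentyp="sachtextanalyse",
    name="Demo-Vorlage",
    beschreibung="Minimalbeispiel mit einem Kriterium",
    kriterien=[
        EHKriterium(
            id="inhalt",
            name="Inhaltliche Leistung",
            beschreibung="Erfassung der Kernaussage",
            gewichtung=100,
            erwartungen=["Korrekte Erfassung der Kernthese"],
        )
    ],
    einleitung_hinweise=[],
    hauptteil_hinweise=[],
    schluss_hinweise=[],
    sprachliche_aspekte=[],
)
print(demo.to_dict()["kriterien"][0]["gewichtung"])  # 100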
@@ -1,65 +1,4 @@
#!/usr/bin/env python3
"""
Full Compliance Pipeline for Legal Corpus — Barrel Re-export.
Split into submodules:
- compliance_models.py — Dataclasses (Checkpoint, Control, Measure)
- compliance_extraction.py — Pattern extraction & control/measure generation
- compliance_pipeline.py — Pipeline phases & orchestrator
Run on Mac Mini:
nohup python full_compliance_pipeline.py > /tmp/compliance_pipeline.log 2>&1 &
"""
import asyncio
import logging
import sys
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout),
logging.FileHandler('/tmp/compliance_pipeline.log')
]
)
# Re-export all public symbols
from compliance_models import Checkpoint, Control, Measure
from compliance_extraction import (
extract_checkpoints_from_chunk,
generate_control_for_checkpoints,
generate_measure_for_control,
)
from compliance_pipeline import CompliancePipeline
__all__ = [
"Checkpoint",
"Control",
"Measure",
"extract_checkpoints_from_chunk",
"generate_control_for_checkpoints",
"generate_measure_for_control",
"CompliancePipeline",
]
async def main():
import argparse
parser = argparse.ArgumentParser(description="Run the compliance pipeline")
parser.add_argument("--force-reindex", action="store_true",
help="Force re-ingestion of all documents")
parser.add_argument("--skip-ingestion", action="store_true",
help="Skip ingestion phase, use existing chunks")
args = parser.parse_args()
pipeline = CompliancePipeline()
await pipeline.run_full_pipeline(
force_reindex=args.force_reindex,
skip_ingestion=args.skip_ingestion
)
if __name__ == "__main__":
asyncio.run(main())
# Backward-compat shim -- module moved to compliance/full_pipeline.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("compliance.full_pipeline")
@@ -0,0 +1,6 @@
"""
korrektur package — exam correction, EH templates, PDF export.
Backward-compatible re-exports: consumers can still use
``from eh_pipeline import ...`` etc. via the shim files in backend/.
"""
@@ -0,0 +1,420 @@
"""
BYOEH Processing Pipeline
Handles chunking, embedding generation, and encryption for Erwartungshorizonte.
Supports multiple embedding backends:
- local: sentence-transformers (default, no API key needed)
- openai: OpenAI text-embedding-3-small (requires OPENAI_API_KEY)
"""
import os
import io
import base64
import hashlib
from typing import List, Tuple, Optional
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives import hashes
import httpx
# Embedding Configuration
# Backend: "local" (sentence-transformers) or "openai"
EMBEDDING_BACKEND = os.getenv("EMBEDDING_BACKEND", "local")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
# Local embedding model (all-MiniLM-L6-v2: 384 dimensions, fast, good quality)
LOCAL_EMBEDDING_MODEL = os.getenv("LOCAL_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
# Vector dimensions per backend
VECTOR_DIMENSIONS = {
"local": 384, # all-MiniLM-L6-v2
"openai": 1536, # text-embedding-3-small
}
CHUNK_SIZE = int(os.getenv("BYOEH_CHUNK_SIZE", "1000"))
CHUNK_OVERLAP = int(os.getenv("BYOEH_CHUNK_OVERLAP", "200"))
# Lazy-loaded sentence-transformers model
_local_model = None
class ChunkingError(Exception):
"""Error during text chunking."""
pass
class EmbeddingError(Exception):
"""Error during embedding generation."""
pass
class EncryptionError(Exception):
"""Error during encryption/decryption."""
pass
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
"""
Split text into overlapping chunks.
Uses a simple recursive character splitter approach:
- Try to split on paragraph boundaries first
- Then sentences
- Then words
- Finally characters
Args:
text: Input text to chunk
chunk_size: Target chunk size in characters
overlap: Overlap between chunks
Returns:
List of text chunks
"""
if not text or len(text) <= chunk_size:
return [text] if text else []
chunks = []
separators = ["\n\n", "\n", ". ", " ", ""]
def split_recursive(text: str, sep_idx: int = 0) -> List[str]:
if len(text) <= chunk_size:
return [text]
if sep_idx >= len(separators):
# Last resort: hard split
return [text[i:i+chunk_size] for i in range(0, len(text), chunk_size - overlap)]
sep = separators[sep_idx]
if not sep:
# Empty separator = character split
parts = list(text)
else:
parts = text.split(sep)
result = []
current = ""
for part in parts:
test_chunk = current + sep + part if current else part
if len(test_chunk) <= chunk_size:
current = test_chunk
else:
if current:
result.append(current)
# If single part is too big, recursively split it
if len(part) > chunk_size:
result.extend(split_recursive(part, sep_idx + 1))
current = ""
else:
current = part
if current:
result.append(current)
return result
raw_chunks = split_recursive(text)
# Add overlap
final_chunks = []
for i, chunk in enumerate(raw_chunks):
if i > 0 and overlap > 0:
# Add overlap from previous chunk
prev_chunk = raw_chunks[i-1]
overlap_text = prev_chunk[-min(overlap, len(prev_chunk)):]
chunk = overlap_text + chunk
final_chunks.append(chunk.strip())
return [c for c in final_chunks if c]
def get_vector_size() -> int:
"""Get the vector dimension for the current embedding backend."""
return VECTOR_DIMENSIONS.get(EMBEDDING_BACKEND, 384)
def _get_local_model():
"""Lazy-load the sentence-transformers model."""
global _local_model
if _local_model is None:
try:
from sentence_transformers import SentenceTransformer
print(f"Loading local embedding model: {LOCAL_EMBEDDING_MODEL}")
_local_model = SentenceTransformer(LOCAL_EMBEDDING_MODEL)
print(f"Model loaded successfully (dim={_local_model.get_sentence_embedding_dimension()})")
except ImportError:
raise EmbeddingError(
"sentence-transformers not installed. "
"Install with: pip install sentence-transformers"
)
return _local_model
def _generate_local_embeddings(texts: List[str]) -> List[List[float]]:
"""Generate embeddings using local sentence-transformers model."""
if not texts:
return []
model = _get_local_model()
embeddings = model.encode(texts, show_progress_bar=len(texts) > 10)
return [emb.tolist() for emb in embeddings]
async def _generate_openai_embeddings(texts: List[str]) -> List[List[float]]:
"""Generate embeddings using OpenAI API."""
if not OPENAI_API_KEY:
raise EmbeddingError("OPENAI_API_KEY not configured")
try:
async with httpx.AsyncClient() as client:
response = await client.post(
"https://api.openai.com/v1/embeddings",
headers={
"Authorization": f"Bearer {OPENAI_API_KEY}",
"Content-Type": "application/json"
},
json={
"model": EMBEDDING_MODEL,
"input": texts
},
timeout=60.0
)
if response.status_code != 200:
raise EmbeddingError(f"OpenAI API error: {response.status_code} - {response.text}")
data = response.json()
embeddings = [item["embedding"] for item in data["data"]]
return embeddings
except httpx.TimeoutException:
raise EmbeddingError("OpenAI API timeout")
except Exception as e:
raise EmbeddingError(f"Failed to generate embeddings: {str(e)}")
async def generate_embeddings(texts: List[str]) -> List[List[float]]:
"""
Generate embeddings using configured backend.
Backends:
- local: sentence-transformers (default, no API key needed)
- openai: OpenAI text-embedding-3-small
Args:
texts: List of text chunks
Returns:
List of embedding vectors
Raises:
EmbeddingError: If embedding generation fails
"""
if not texts:
return []
if EMBEDDING_BACKEND == "local":
# Local model runs synchronously but is fast
return _generate_local_embeddings(texts)
elif EMBEDDING_BACKEND == "openai":
return await _generate_openai_embeddings(texts)
else:
raise EmbeddingError(f"Unknown embedding backend: {EMBEDDING_BACKEND}")
async def generate_single_embedding(text: str) -> List[float]:
"""Generate embedding for a single text."""
embeddings = await generate_embeddings([text])
return embeddings[0] if embeddings else []
def derive_key(passphrase: str, salt: bytes) -> bytes:
"""
Derive encryption key from passphrase using PBKDF2.
Args:
passphrase: User passphrase
salt: Random salt (16 bytes)
Returns:
32-byte AES key
"""
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
)
return kdf.derive(passphrase.encode())
def encrypt_text(text: str, passphrase: str, salt_hex: str) -> str:
"""
Encrypt text using AES-256-GCM.
Args:
text: Plaintext to encrypt
passphrase: User passphrase
salt_hex: Salt as hex string
Returns:
Base64-encoded ciphertext (IV + ciphertext)
"""
try:
salt = bytes.fromhex(salt_hex)
key = derive_key(passphrase, salt)
aesgcm = AESGCM(key)
iv = os.urandom(12)
ciphertext = aesgcm.encrypt(iv, text.encode(), None)
# Combine IV + ciphertext
combined = iv + ciphertext
return base64.b64encode(combined).decode()
except Exception as e:
raise EncryptionError(f"Encryption failed: {str(e)}")
def decrypt_text(encrypted_b64: str, passphrase: str, salt_hex: str) -> str:
"""
Decrypt text using AES-256-GCM.
Args:
encrypted_b64: Base64-encoded ciphertext (IV + ciphertext)
passphrase: User passphrase
salt_hex: Salt as hex string
Returns:
Decrypted plaintext
"""
try:
salt = bytes.fromhex(salt_hex)
key = derive_key(passphrase, salt)
combined = base64.b64decode(encrypted_b64)
iv = combined[:12]
ciphertext = combined[12:]
aesgcm = AESGCM(key)
plaintext = aesgcm.decrypt(iv, ciphertext, None)
return plaintext.decode()
except Exception as e:
raise EncryptionError(f"Decryption failed: {str(e)}")
def hash_key(passphrase: str, salt_hex: str) -> str:
"""
Create SHA-256 hash of derived key for verification.
Args:
passphrase: User passphrase
salt_hex: Salt as hex string
Returns:
Hex-encoded key hash
"""
salt = bytes.fromhex(salt_hex)
key = derive_key(passphrase, salt)
return hashlib.sha256(key).hexdigest()
def verify_key_hash(passphrase: str, salt_hex: str, expected_hash: str) -> bool:
"""
Verify passphrase matches stored key hash.
Args:
passphrase: User passphrase to verify
salt_hex: Salt as hex string
expected_hash: Expected key hash
Returns:
True if passphrase is correct
"""
computed_hash = hash_key(passphrase, salt_hex)
return computed_hash == expected_hash
def extract_text_from_pdf(pdf_content: bytes) -> str:
"""
Extract text from PDF file.
Args:
pdf_content: Raw PDF bytes
Returns:
Extracted text
"""
try:
import PyPDF2
pdf_file = io.BytesIO(pdf_content)
reader = PyPDF2.PdfReader(pdf_file)
text_parts = []
for page in reader.pages:
text = page.extract_text()
if text:
text_parts.append(text)
return "\n\n".join(text_parts)
except ImportError:
raise ChunkingError("PyPDF2 not installed")
except Exception as e:
raise ChunkingError(f"Failed to extract PDF text: {str(e)}")
async def process_eh_for_indexing(
eh_id: str,
tenant_id: str,
subject: str,
text_content: str,
passphrase: str,
salt_hex: str
) -> Tuple[int, List[dict]]:
"""
Full processing pipeline for Erwartungshorizont indexing.
1. Chunk the text
2. Generate embeddings
3. Encrypt chunks
4. Return prepared data for Qdrant
Args:
eh_id: Erwartungshorizont ID
tenant_id: Tenant ID
subject: Subject (deutsch, englisch, etc.)
text_content: Decrypted text content
passphrase: User passphrase for re-encryption
salt_hex: Salt for encryption
Returns:
Tuple of (chunk_count, chunks_data)
"""
# 1. Chunk the text
chunks = chunk_text(text_content)
if not chunks:
return 0, []
# 2. Generate embeddings
embeddings = await generate_embeddings(chunks)
# 3. Encrypt chunks for storage
encrypted_chunks = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
encrypted_content = encrypt_text(chunk, passphrase, salt_hex)
encrypted_chunks.append({
"chunk_index": i,
"embedding": embedding,
"encrypted_content": encrypted_content
})
return len(chunks), encrypted_chunks
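# Illustrative usage sketch. Assumes an embedding backend is reachable; all
# identifiers and the sample text below are example values.
if __name__ == "__main__":
    import asyncio

    async def _demo_indexing() -> None:
        salt_hex = os.urandom(16).hex()
        count, chunks = await process_eh_for_indexing(
            eh_id="eh-demo",
            tenant_id="tenant-demo",
            subject="deutsch",
            text_content="Beispieltext fuer den Erwartungshorizont. " * 50,
            passphrase="beispiel-passphrase",
            salt_hex=salt_hex,
        )
        # Each prepared chunk carries its index, an embedding vector and the
        # AES-GCM ciphertext ready for payload storage.
        print(count, sorted(chunks[0].keys()) if chunks else [])

    asyncio.run(_demo_indexing())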
@@ -0,0 +1,34 @@
"""
Erwartungshorizont Templates for Vorabitur Mode — barrel re-export.
The actual code lives in:
- eh_templates_types.py (AUFGABENTYPEN, EHKriterium, EHTemplate)
- eh_templates_analyse.py (Textanalyse, Gedicht, Prosa, Drama)
- eh_templates_eroerterung.py (Eroerterung textgebunden)
- eh_templates_registry.py (TEMPLATES, get_template, list_templates, etc.)
"""
# Types
from .eh_templates_types import ( # noqa: F401
AUFGABENTYPEN,
EHKriterium,
EHTemplate,
)
# Template factories
from .eh_templates_analyse import ( # noqa: F401
get_textanalyse_template,
get_gedichtanalyse_template,
get_prosaanalyse_template,
get_dramenanalyse_template,
)
from .eh_templates_eroerterung import get_eroerterung_template # noqa: F401
# Registry
from .eh_templates_registry import ( # noqa: F401
TEMPLATES,
initialize_templates,
get_template,
list_templates,
get_aufgabentypen,
)
@@ -0,0 +1,395 @@
"""
Erwartungshorizont Templates — Analyse templates.
Contains templates for:
- Textanalyse (pragmatische Texte)
- Gedichtanalyse / Lyrikinterpretation
- Prosaanalyse
- Dramenanalyse
"""
from .eh_templates_types import EHTemplate, EHKriterium
def get_textanalyse_template() -> EHTemplate:
"""Template for pragmatic text analysis."""
return EHTemplate(
id="template_textanalyse_pragmatisch",
aufgabentyp="textanalyse_pragmatisch",
name="Textanalyse pragmatischer Texte",
beschreibung="Vorlage fuer die Analyse von Sachtexten, Reden, Kommentaren und Essays",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Erfassung und Wiedergabe des Textinhalts",
gewichtung=40,
erwartungen=[
"Korrekte Erfassung der Textaussage/These",
"Vollstaendige Wiedergabe der Argumentationsstruktur",
"Erkennen von Intention und Adressatenbezug",
"Einordnung in den historischen/gesellschaftlichen Kontext",
"Beruecksichtigung aller relevanten Textaspekte"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau und Gliederung der Analyse",
gewichtung=15,
erwartungen=[
"Sinnvolle Einleitung mit Basisinformationen",
"Logische Gliederung des Hauptteils",
"Stringente Gedankenfuehrung",
"Angemessener Schluss mit Fazit/Wertung",
"Absatzgliederung und Ueberlaenge"
]
),
EHKriterium(
id="analyse",
name="Analytische Qualitaet",
beschreibung="Tiefe und Qualitaet der Analyse",
gewichtung=15,
erwartungen=[
"Erkennen rhetorischer Mittel",
"Funktionale Deutung der Stilmittel",
"Analyse der Argumentationsweise",
"Beruecksichtigung von Wortwahl und Satzbau",
"Verknuepfung von Form und Inhalt"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung",
"Korrekte Gross- und Kleinschreibung",
"Korrekte Getrennt- und Zusammenschreibung",
"Korrekte Fremdwortschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Flexion",
"Korrekte Zeichensetzung",
"Korrekte Bezuege und Kongruenz"
]
)
],
einleitung_hinweise=[
"Nennung von Autor, Titel, Textsorte, Erscheinungsjahr",
"Benennung des Themas",
"Formulierung der Kernthese/Hauptaussage",
"Ggf. Einordnung in den Kontext"
],
hauptteil_hinweise=[
"Systematische Analyse der Argumentationsstruktur",
"Untersuchung der sprachlichen Gestaltung",
"Funktionale Deutung der Stilmittel",
"Beruecksichtigung von Adressatenbezug und Intention",
"Textbelege durch Zitate"
],
schluss_hinweise=[
"Zusammenfassung der Analyseergebnisse",
"Bewertung der Ueberzeugungskraft",
"Ggf. aktuelle Relevanz",
"Persoenliche Stellungnahme (wenn gefordert)"
],
sprachliche_aspekte=[
"Fachsprachliche Begriffe korrekt verwenden",
"Konjunktiv fuer indirekte Rede",
"Praesens als Tempus der Analyse",
"Sachlicher, analytischer Stil"
]
)
def get_gedichtanalyse_template() -> EHTemplate:
"""Template for poetry analysis."""
return EHTemplate(
id="template_gedichtanalyse",
aufgabentyp="gedichtanalyse",
name="Gedichtanalyse / Lyrikinterpretation",
beschreibung="Vorlage fuer die Analyse und Interpretation lyrischer Texte",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Erfassung und Deutung des Gedichtinhalts",
gewichtung=40,
erwartungen=[
"Korrekte Erfassung des lyrischen Ichs und der Sprechsituation",
"Vollstaendige inhaltliche Erschliessung aller Strophen",
"Erkennen der zentralen Motive und Themen",
"Epochenzuordnung und literaturgeschichtliche Einordnung",
"Deutung der Bildlichkeit und Symbolik"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau der Interpretation",
gewichtung=15,
erwartungen=[
"Einleitung mit Basisinformationen",
"Systematische strophenweise oder aspektorientierte Analyse",
"Verknuepfung von Form- und Inhaltsanalyse",
"Schluessige Gesamtdeutung im Schluss"
]
),
EHKriterium(
id="formanalyse",
name="Formale Analyse",
beschreibung="Analyse der lyrischen Gestaltungsmittel",
gewichtung=15,
erwartungen=[
"Bestimmung von Metrum und Reimschema",
"Analyse der Klanggestaltung",
"Erkennen von Enjambements und Zaesuren",
"Deutung der formalen Mittel",
"Verknuepfung von Form und Inhalt"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung",
"Korrekte Gross- und Kleinschreibung",
"Korrekte Getrennt- und Zusammenschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Flexion",
"Korrekte Zeichensetzung"
]
)
],
einleitung_hinweise=[
"Autor, Titel, Entstehungsjahr/Epoche",
"Thema/Motiv des Gedichts",
"Erste Deutungshypothese",
"Formale Grunddaten (Strophen, Verse)"
],
hauptteil_hinweise=[
"Inhaltliche Analyse (strophenweise oder aspektorientiert)",
"Formale Analyse (Metrum, Reim, Klang)",
"Sprachliche Analyse (Stilmittel, Bildlichkeit)",
"Funktionale Verknuepfung aller Ebenen",
"Textbelege durch Zitate mit Versangabe"
],
schluss_hinweise=[
"Zusammenfassung der Interpretationsergebnisse",
"Bestaetigung/Modifikation der Deutungshypothese",
"Einordnung in Epoche/Werk des Autors",
"Aktualitaetsbezug (wenn sinnvoll)"
],
sprachliche_aspekte=[
"Fachbegriffe der Lyrikanalyse verwenden",
"Zwischen lyrischem Ich und Autor unterscheiden",
"Praesens als Analysetempus",
"Deutende statt beschreibende Formulierungen"
]
)
def get_prosaanalyse_template() -> EHTemplate:
"""Template for prose/narrative text analysis."""
return EHTemplate(
id="template_prosaanalyse",
aufgabentyp="prosaanalyse",
name="Epische Textanalyse / Prosaanalyse",
beschreibung="Vorlage fuer die Analyse von Romanauszuegen, Kurzgeschichten und Novellen",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Erfassung und Deutung des Textinhalts",
gewichtung=40,
erwartungen=[
"Korrekte Erfassung der Handlung",
"Charakterisierung der Figuren",
"Erkennen der Erzaehlsituation",
"Deutung der Konflikte und Motive",
"Einordnung in den Gesamtzusammenhang"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau der Analyse",
gewichtung=15,
erwartungen=[
"Informative Einleitung",
"Systematische Analyse im Hauptteil",
"Verknuepfung der Analyseergebnisse",
"Schluessige Gesamtdeutung"
]
),
EHKriterium(
id="erzaehltechnik",
name="Erzaehltechnische Analyse",
beschreibung="Analyse narrativer Gestaltungsmittel",
gewichtung=15,
erwartungen=[
"Bestimmung der Erzaehlperspektive",
"Analyse von Zeitgestaltung",
"Raumgestaltung und Atmosphaere",
"Figurenrede und Bewusstseinsdarstellung",
"Funktionale Deutung"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung",
"Korrekte Gross- und Kleinschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Zeichensetzung"
]
)
],
einleitung_hinweise=[
"Autor, Titel, Textsorte, Erscheinungsjahr",
"Einordnung des Auszugs in den Gesamttext",
"Thema und Deutungshypothese"
],
hauptteil_hinweise=[
"Kurze Inhaltsangabe des Auszugs",
"Analyse der Handlungsstruktur",
"Figurenanalyse mit Textbelegen",
"Erzaehltechnische Analyse",
"Sprachliche Analyse",
"Verknuepfung aller Ebenen"
],
schluss_hinweise=[
"Zusammenfassung der Analyseergebnisse",
"Bestaetigung der Deutungshypothese",
"Bedeutung fuer Gesamtwerk",
"Ggf. Aktualitaetsbezug"
],
sprachliche_aspekte=[
"Fachbegriffe der Erzaehltextanalyse",
"Zwischen Erzaehler und Autor unterscheiden",
"Praesens als Analysetempus",
"Deutende Formulierungen"
]
)
def get_dramenanalyse_template() -> EHTemplate:
"""Template for drama analysis."""
return EHTemplate(
id="template_dramenanalyse",
aufgabentyp="dramenanalyse",
name="Dramenanalyse",
beschreibung="Vorlage fuer die Analyse dramatischer Texte und Szenen",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Erfassung und Deutung des Szeneninhalts",
gewichtung=40,
erwartungen=[
"Korrekte Erfassung der Handlung",
"Analyse der Figurenkonstellation",
"Erkennen des dramatischen Konflikts",
"Einordnung in den Handlungsverlauf",
"Deutung der Szene im Gesamtzusammenhang"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau der Analyse",
gewichtung=15,
erwartungen=[
"Einleitung mit Kontextualisierung",
"Systematische Szenenanalyse",
"Verknuepfung der Analyseergebnisse",
"Schluessige Deutung"
]
),
EHKriterium(
id="dramentechnik",
name="Dramentechnische Analyse",
beschreibung="Analyse dramatischer Gestaltungsmittel",
gewichtung=15,
erwartungen=[
"Analyse der Dialoggestaltung",
"Regieanweisungen und Buehnenraum",
"Dramatische Spannung",
"Monolog/Dialog-Formen",
"Funktionale Deutung"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Zeichensetzung"
]
)
],
einleitung_hinweise=[
"Autor, Titel, Urauffuehrungsjahr, Dramenform",
"Einordnung der Szene in den Handlungsverlauf",
"Thema und Deutungshypothese"
],
hauptteil_hinweise=[
"Situierung der Szene",
"Analyse des Dialogverlaufs",
"Figurenanalyse im Dialog",
"Sprachliche Analyse",
"Dramentechnische Mittel",
"Bedeutung fuer den Konflikt"
],
schluss_hinweise=[
"Zusammenfassung der Analyseergebnisse",
"Funktion der Szene im Drama",
"Bedeutung fuer die Gesamtdeutung"
],
sprachliche_aspekte=[
"Fachbegriffe der Dramenanalyse",
"Praesens als Analysetempus",
"Korrekte Zitierweise mit Akt/Szene/Zeile"
]
)
@@ -0,0 +1,101 @@
"""
Erwartungshorizont Templates — Eroerterung template.
"""
from .eh_templates_types import EHTemplate, EHKriterium
def get_eroerterung_template() -> EHTemplate:
"""Template for textgebundene Eroerterung."""
return EHTemplate(
id="template_eroerterung_textgebunden",
aufgabentyp="eroerterung_textgebunden",
name="Textgebundene Eroerterung",
beschreibung="Vorlage fuer die textgebundene Eroerterung auf Basis eines Sachtextes",
kriterien=[
EHKriterium(
id="inhalt",
name="Inhaltliche Leistung",
beschreibung="Qualitaet der Argumentation",
gewichtung=40,
erwartungen=[
"Korrekte Wiedergabe der Textposition",
"Differenzierte eigene Argumentation",
"Vielfaeltige und ueberzeugende Argumente",
"Beruecksichtigung von Pro und Contra",
"Sinnvolle Beispiele und Belege",
"Eigenstaendige Schlussfolgerung"
]
),
EHKriterium(
id="struktur",
name="Aufbau und Struktur",
beschreibung="Logischer Aufbau der Eroerterung",
gewichtung=15,
erwartungen=[
"Problemorientierte Einleitung",
"Klare Gliederung der Argumentation",
"Logische Argumentationsfolge",
"Sinnvolle Ueberlaetze",
"Begruendetes Fazit"
]
),
EHKriterium(
id="textbezug",
name="Textbezug",
beschreibung="Verknuepfung mit dem Ausgangstext",
gewichtung=15,
erwartungen=[
"Angemessene Textwiedergabe",
"Kritische Auseinandersetzung mit Textposition",
"Korrekte Zitierweise",
"Verknuepfung eigener Argumente mit Text"
]
),
EHKriterium(
id="rechtschreibung",
name="Sprachliche Richtigkeit (Rechtschreibung)",
beschreibung="Orthografische Korrektheit",
gewichtung=15,
erwartungen=[
"Korrekte Rechtschreibung",
"Korrekte Gross- und Kleinschreibung"
]
),
EHKriterium(
id="grammatik",
name="Sprachliche Richtigkeit (Grammatik)",
beschreibung="Grammatische Korrektheit und Zeichensetzung",
gewichtung=15,
erwartungen=[
"Korrekter Satzbau",
"Korrekte Zeichensetzung",
"Variationsreicher Ausdruck"
]
)
],
einleitung_hinweise=[
"Hinfuehrung zum Thema",
"Nennung des Ausgangstextes",
"Formulierung der Leitfrage/These",
"Ueberleitung zum Hauptteil"
],
hauptteil_hinweise=[
"Kurze Wiedergabe der Textposition",
"Systematische Argumentation (dialektisch oder linear)",
"Jedes Argument: These - Begruendung - Beispiel",
"Gewichtung der Argumente",
"Verknuepfung mit Textposition"
],
schluss_hinweise=[
"Zusammenfassung der wichtigsten Argumente",
"Eigene begruendete Stellungnahme",
"Ggf. Ausblick oder Appell"
],
sprachliche_aspekte=[
"Argumentative Konnektoren verwenden",
"Sachlicher, ueberzeugender Stil",
"Eigene Meinung kennzeichnen",
"Konjunktiv fuer Textpositionen"
]
)
@@ -0,0 +1,60 @@
"""
Erwartungshorizont Templates — registry for template lookup.
"""
from typing import Dict, List, Optional
from .eh_templates_types import EHTemplate, AUFGABENTYPEN
from .eh_templates_analyse import (
get_textanalyse_template,
get_gedichtanalyse_template,
get_prosaanalyse_template,
get_dramenanalyse_template,
)
from .eh_templates_eroerterung import get_eroerterung_template
TEMPLATES: Dict[str, EHTemplate] = {}
def initialize_templates():
"""Initialize all pre-defined templates."""
global TEMPLATES
TEMPLATES = {
"textanalyse_pragmatisch": get_textanalyse_template(),
"gedichtanalyse": get_gedichtanalyse_template(),
"eroerterung_textgebunden": get_eroerterung_template(),
"prosaanalyse": get_prosaanalyse_template(),
"dramenanalyse": get_dramenanalyse_template(),
}
def get_template(aufgabentyp: str) -> Optional[EHTemplate]:
"""Get a template by Aufgabentyp."""
if not TEMPLATES:
initialize_templates()
return TEMPLATES.get(aufgabentyp)
def list_templates() -> List[Dict]:
"""List all available templates."""
if not TEMPLATES:
initialize_templates()
return [
{
"aufgabentyp": typ,
"name": AUFGABENTYPEN.get(typ, {}).get("name", typ),
"description": AUFGABENTYPEN.get(typ, {}).get("description", ""),
"category": AUFGABENTYPEN.get(typ, {}).get("category", "other"),
}
for typ in TEMPLATES.keys()
]
def get_aufgabentypen() -> Dict:
"""Get all Aufgabentypen definitions."""
return AUFGABENTYPEN
# Initialize on import
initialize_templates()
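# Illustrative usage sketch of the registry lookup API (run via ``python -m``
# inside the enclosing package because of the relative imports); output depends
# on the template definitions above.
if __name__ == "__main__":
    template = get_template("gedichtanalyse")
    if template is not None:
        print(template.name, [k.id for k in template.kriterien])
    for entry in list_templates():
        print(f"{entry['aufgabentyp']} -> {entry['category']}")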
@@ -0,0 +1,100 @@
"""
Erwartungshorizont Templates — types and Aufgabentypen registry.
"""
from typing import Dict, List, Optional
from dataclasses import dataclass, field, asdict
from datetime import datetime
AUFGABENTYPEN = {
"textanalyse_pragmatisch": {
"name": "Textanalyse (pragmatische Texte)",
"description": "Analyse von Sachtexten, Reden, Kommentaren, Essays",
"category": "analyse"
},
"sachtextanalyse": {
"name": "Sachtextanalyse",
"description": "Analyse von informativen und appellativen Sachtexten",
"category": "analyse"
},
"gedichtanalyse": {
"name": "Gedichtanalyse / Lyrikinterpretation",
"description": "Analyse und Interpretation lyrischer Texte",
"category": "interpretation"
},
"dramenanalyse": {
"name": "Dramenanalyse",
"description": "Analyse dramatischer Texte und Szenen",
"category": "interpretation"
},
"prosaanalyse": {
"name": "Epische Textanalyse / Prosaanalyse",
"description": "Analyse von Romanauszuegen, Kurzgeschichten, Novellen",
"category": "interpretation"
},
"eroerterung_textgebunden": {
"name": "Textgebundene Eroerterung",
"description": "Eroerterung auf Basis eines Sachtextes",
"category": "argumentation"
},
"eroerterung_frei": {
"name": "Freie Eroerterung",
"description": "Freie Eroerterung zu einem Thema",
"category": "argumentation"
},
"eroerterung_literarisch": {
"name": "Literarische Eroerterung",
"description": "Eroerterung zu literarischen Fragestellungen",
"category": "argumentation"
},
"materialgestuetzt": {
"name": "Materialgestuetztes Schreiben",
"description": "Verfassen eines Textes auf Materialbasis",
"category": "produktion"
}
}
@dataclass
class EHKriterium:
"""Single criterion in an Erwartungshorizont."""
id: str
name: str
beschreibung: str
gewichtung: int # Percentage weight (0-100)
erwartungen: List[str] # Expected points/elements
max_punkte: int = 100
def to_dict(self):
return asdict(self)
@dataclass
class EHTemplate:
"""Complete Erwartungshorizont template."""
id: str
aufgabentyp: str
name: str
beschreibung: str
kriterien: List[EHKriterium]
einleitung_hinweise: List[str]
hauptteil_hinweise: List[str]
schluss_hinweise: List[str]
sprachliche_aspekte: List[str]
created_at: datetime = field(default_factory=lambda: datetime.now())
def to_dict(self):
d = {
'id': self.id,
'aufgabentyp': self.aufgabentyp,
'name': self.name,
'beschreibung': self.beschreibung,
'kriterien': [k.to_dict() for k in self.kriterien],
'einleitung_hinweise': self.einleitung_hinweise,
'hauptteil_hinweise': self.hauptteil_hinweise,
'schluss_hinweise': self.schluss_hinweise,
'sprachliche_aspekte': self.sprachliche_aspekte,
'created_at': self.created_at.isoformat()
}
return d
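# Illustrative construction sketch; the single criterion and all strings below
# are example values chosen only to show the field shapes and the to_dict() output.
if __name__ == "__main__":
    beispiel = EHTemplate(
        id="template_demo",
        aufgabentyp="sachtextanalyse",
        name="Demo-Vorlage",
        beschreibung="Nur zur Illustration",
        kriterien=[
            EHKriterium(
                id="inhalt",
                name="Inhaltliche Leistung",
                beschreibung="Demo-Kriterium",
                gewichtung=100,
                erwartungen=["Erfassung der Textaussage"],
            )
        ],
        einleitung_hinweise=["Autor und Titel nennen"],
        hauptteil_hinweise=["Argumentation nachzeichnen"],
        schluss_hinweise=["Fazit ziehen"],
        sprachliche_aspekte=["Praesens verwenden"],
    )
    print(beispiel.to_dict()["kriterien"][0]["gewichtung"])  # -> 100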
@@ -0,0 +1,17 @@
"""
PDF Export Module for Abiturkorrektur System
Barrel re-export: all PDF generation functions and constants.
"""
from .pdf_export_styles import ( # noqa: F401
GRADE_POINTS_TO_NOTE,
CRITERIA_DISPLAY_NAMES,
CRITERIA_WEIGHTS,
get_custom_styles,
)
from .pdf_export_gutachten import generate_gutachten_pdf # noqa: F401
from .pdf_export_overview import ( # noqa: F401
generate_klausur_overview_pdf,
generate_annotations_pdf,
)
@@ -0,0 +1,315 @@
"""
PDF Export - Individual Gutachten PDF generation.
Generates a single student's Gutachten with criteria table,
workflow info, and annotation summary.
"""
import io
from datetime import datetime
from typing import Dict, List, Optional, Any
from reportlab.lib import colors
from reportlab.lib.pagesizes import A4
from reportlab.lib.units import cm
from reportlab.platypus import (
SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle,
HRFlowable, KeepTogether
)
from .pdf_export_styles import (
GRADE_POINTS_TO_NOTE,
CRITERIA_DISPLAY_NAMES,
CRITERIA_WEIGHTS,
get_custom_styles,
)
def generate_gutachten_pdf(
student_data: Dict[str, Any],
klausur_data: Dict[str, Any],
    annotations: Optional[List[Dict[str, Any]]] = None,
    workflow_data: Optional[Dict[str, Any]] = None
) -> bytes:
"""
Generate a PDF Gutachten for a single student.
Args:
student_data: Student work data including criteria_scores, gutachten, grade_points
klausur_data: Klausur metadata (title, subject, year, etc.)
annotations: List of annotations for annotation summary
workflow_data: Examiner workflow data (EK, ZK, DK info)
Returns:
PDF as bytes
"""
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
pagesize=A4,
rightMargin=2*cm,
leftMargin=2*cm,
topMargin=2*cm,
bottomMargin=2*cm
)
styles = get_custom_styles()
story = []
# Header
story.append(Paragraph("Gutachten zur Abiturklausur", styles['GutachtenTitle']))
story.append(Paragraph(f"{klausur_data.get('subject', 'Deutsch')} - {klausur_data.get('title', '')}", styles['GutachtenSubtitle']))
story.append(Spacer(1, 0.5*cm))
# Meta information table
meta_data = [
["Pruefling:", student_data.get('student_name', 'Anonym')],
["Schuljahr:", f"{klausur_data.get('year', 2025)}"],
["Kurs:", klausur_data.get('semester', 'Abitur')],
["Datum:", datetime.now().strftime("%d.%m.%Y")]
]
meta_table = Table(meta_data, colWidths=[4*cm, 10*cm])
meta_table.setStyle(TableStyle([
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 10),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('TOPPADDING', (0, 0), (-1, -1), 4),
]))
story.append(meta_table)
story.append(Spacer(1, 0.5*cm))
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.5*cm))
# Gutachten content
_add_gutachten_content(story, styles, student_data)
story.append(Spacer(1, 0.5*cm))
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.5*cm))
# Bewertungstabelle
_add_criteria_table(story, styles, student_data)
# Final grade box
_add_grade_box(story, styles, student_data)
# Examiner workflow information
if workflow_data:
_add_workflow_info(story, styles, workflow_data)
# Annotation summary
if annotations:
_add_annotation_summary(story, styles, annotations)
# Footer
_add_footer(story, styles)
# Build PDF
doc.build(story)
buffer.seek(0)
return buffer.getvalue()
def _add_gutachten_content(story, styles, student_data):
"""Add gutachten text sections to the story."""
gutachten = student_data.get('gutachten', {})
if gutachten:
if gutachten.get('einleitung'):
story.append(Paragraph("Einleitung", styles['SectionHeader']))
story.append(Paragraph(gutachten['einleitung'], styles['GutachtenBody']))
story.append(Spacer(1, 0.3*cm))
if gutachten.get('hauptteil'):
story.append(Paragraph("Hauptteil", styles['SectionHeader']))
story.append(Paragraph(gutachten['hauptteil'], styles['GutachtenBody']))
story.append(Spacer(1, 0.3*cm))
if gutachten.get('fazit'):
story.append(Paragraph("Fazit", styles['SectionHeader']))
story.append(Paragraph(gutachten['fazit'], styles['GutachtenBody']))
story.append(Spacer(1, 0.3*cm))
if gutachten.get('staerken') or gutachten.get('schwaechen'):
story.append(Spacer(1, 0.3*cm))
if gutachten.get('staerken'):
story.append(Paragraph("Staerken:", styles['SectionHeader']))
for s in gutachten['staerken']:
story.append(Paragraph(f"{s}", styles['ListItem']))
if gutachten.get('schwaechen'):
story.append(Paragraph("Verbesserungspotenzial:", styles['SectionHeader']))
for s in gutachten['schwaechen']:
story.append(Paragraph(f"{s}", styles['ListItem']))
else:
story.append(Paragraph("<i>Kein Gutachten-Text vorhanden.</i>", styles['GutachtenBody']))
def _add_criteria_table(story, styles, student_data):
"""Add criteria scoring table to the story."""
story.append(Paragraph("Bewertung nach Kriterien", styles['SectionHeader']))
story.append(Spacer(1, 0.2*cm))
criteria_scores = student_data.get('criteria_scores', {})
table_data = [["Kriterium", "Gewichtung", "Erreicht", "Punkte"]]
total_weighted = 0
total_weight = 0
for key, display_name in CRITERIA_DISPLAY_NAMES.items():
weight = CRITERIA_WEIGHTS.get(key, 0)
score_data = criteria_scores.get(key, {})
score = score_data.get('score', 0) if isinstance(score_data, dict) else score_data
weighted_score = (score / 100) * weight if score else 0
total_weighted += weighted_score
total_weight += weight
table_data.append([
display_name,
f"{weight}%",
f"{score}%",
f"{weighted_score:.1f}"
])
table_data.append([
"Gesamt",
f"{total_weight}%",
"",
f"{total_weighted:.1f}"
])
criteria_table = Table(table_data, colWidths=[8*cm, 2.5*cm, 2.5*cm, 2.5*cm])
criteria_table.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, 0), 10),
('ALIGN', (1, 0), (-1, -1), 'CENTER'),
('FONTSIZE', (0, 1), (-1, -1), 9),
('BOTTOMPADDING', (0, 0), (-1, -1), 6),
('TOPPADDING', (0, 0), (-1, -1), 6),
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
('BACKGROUND', (0, -1), (-1, -1), colors.HexColor('#f7fafc')),
('FONTNAME', (0, -1), (-1, -1), 'Helvetica-Bold'),
('ROWBACKGROUNDS', (0, 1), (-1, -2), [colors.white, colors.HexColor('#f7fafc')]),
]))
story.append(criteria_table)
story.append(Spacer(1, 0.5*cm))
def _add_grade_box(story, styles, student_data):
"""Add final grade box to the story."""
grade_points = student_data.get('grade_points', 0)
grade_note = GRADE_POINTS_TO_NOTE.get(grade_points, "?")
raw_points = student_data.get('raw_points', 0)
grade_data = [
["Rohpunkte:", f"{raw_points} / 100"],
["Notenpunkte:", f"{grade_points} Punkte"],
["Note:", grade_note]
]
grade_table = Table(grade_data, colWidths=[4*cm, 4*cm])
grade_table.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, -1), colors.HexColor('#ebf8ff')),
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTNAME', (1, -1), (1, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 11),
('FONTSIZE', (1, -1), (1, -1), 14),
('TEXTCOLOR', (1, -1), (1, -1), colors.HexColor('#2c5282')),
('BOTTOMPADDING', (0, 0), (-1, -1), 8),
('TOPPADDING', (0, 0), (-1, -1), 8),
('LEFTPADDING', (0, 0), (-1, -1), 12),
('BOX', (0, 0), (-1, -1), 1, colors.HexColor('#2c5282')),
('ALIGN', (1, 0), (1, -1), 'RIGHT'),
]))
story.append(KeepTogether([
Paragraph("Endergebnis", styles['SectionHeader']),
Spacer(1, 0.2*cm),
grade_table
]))
def _add_workflow_info(story, styles, workflow_data):
"""Add examiner workflow information to the story."""
story.append(Spacer(1, 0.5*cm))
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.3*cm))
story.append(Paragraph("Korrekturverlauf", styles['SectionHeader']))
workflow_rows = []
if workflow_data.get('erst_korrektor'):
ek = workflow_data['erst_korrektor']
workflow_rows.append([
"Erstkorrektor:",
ek.get('name', 'Unbekannt'),
f"{ek.get('grade_points', '-')} Punkte"
])
if workflow_data.get('zweit_korrektor'):
zk = workflow_data['zweit_korrektor']
workflow_rows.append([
"Zweitkorrektor:",
zk.get('name', 'Unbekannt'),
f"{zk.get('grade_points', '-')} Punkte"
])
if workflow_data.get('dritt_korrektor'):
dk = workflow_data['dritt_korrektor']
workflow_rows.append([
"Drittkorrektor:",
dk.get('name', 'Unbekannt'),
f"{dk.get('grade_points', '-')} Punkte"
])
if workflow_data.get('final_grade_source'):
workflow_rows.append([
"Endnote durch:",
workflow_data['final_grade_source'],
""
])
if workflow_rows:
workflow_table = Table(workflow_rows, colWidths=[4*cm, 6*cm, 4*cm])
workflow_table.setStyle(TableStyle([
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 9),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('TOPPADDING', (0, 0), (-1, -1), 4),
]))
story.append(workflow_table)
def _add_annotation_summary(story, styles, annotations):
"""Add annotation summary to the story."""
story.append(Spacer(1, 0.5*cm))
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.3*cm))
story.append(Paragraph("Anmerkungen (Zusammenfassung)", styles['SectionHeader']))
by_type = {}
for ann in annotations:
ann_type = ann.get('type', 'comment')
if ann_type not in by_type:
by_type[ann_type] = []
by_type[ann_type].append(ann)
for ann_type, anns in by_type.items():
type_name = CRITERIA_DISPLAY_NAMES.get(ann_type, ann_type.replace('_', ' ').title())
story.append(Paragraph(f"{type_name} ({len(anns)} Anmerkungen)", styles['ListItem']))
def _add_footer(story, styles):
"""Add generation footer to the story."""
story.append(Spacer(1, 1*cm))
story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
story.append(Spacer(1, 0.2*cm))
story.append(Paragraph(
f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
styles['MetaText']
))
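# Illustrative usage sketch (run via ``python -m`` inside the enclosing package
# because of the relative imports). The dictionaries only contain keys this
# module reads; names and scores are invented example values.
if __name__ == "__main__":
    pdf_bytes = generate_gutachten_pdf(
        student_data={
            "student_name": "Anonym 01",
            "raw_points": 78,
            "grade_points": 11,
            "criteria_scores": {"inhalt": {"score": 80}, "struktur": {"score": 75}},
            "gutachten": {"einleitung": "Die Arbeit erfasst das Thema sicher."},
        },
        klausur_data={"subject": "Deutsch", "title": "Gedichtanalyse", "year": 2025},
    )
    with open("gutachten_demo.pdf", "wb") as fh:
        fh.write(pdf_bytes)
    print(f"{len(pdf_bytes)} bytes written")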
@@ -0,0 +1,297 @@
"""
PDF Export - Klausur overview and annotations PDF generation.
Generates:
- Klausur overview with grade distribution for all students
- Annotations PDF for a single student
"""
import io
from datetime import datetime
from typing import Dict, List, Optional, Any
from reportlab.lib import colors
from reportlab.lib.pagesizes import A4
from reportlab.lib.units import cm
from reportlab.platypus import (
SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle,
HRFlowable
)
from .pdf_export_styles import (
GRADE_POINTS_TO_NOTE,
CRITERIA_DISPLAY_NAMES,
get_custom_styles,
)
def generate_klausur_overview_pdf(
klausur_data: Dict[str, Any],
students: List[Dict[str, Any]],
fairness_data: Optional[Dict[str, Any]] = None
) -> bytes:
"""
Generate an overview PDF for an entire Klausur with all student grades.
Args:
klausur_data: Klausur metadata
students: List of all student work data
fairness_data: Optional fairness analysis data
Returns:
PDF as bytes
"""
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
pagesize=A4,
rightMargin=1.5*cm,
leftMargin=1.5*cm,
topMargin=2*cm,
bottomMargin=2*cm
)
styles = get_custom_styles()
story = []
# Header
story.append(Paragraph("Notenuebersicht", styles['GutachtenTitle']))
story.append(Paragraph(f"{klausur_data.get('subject', 'Deutsch')} - {klausur_data.get('title', '')}", styles['GutachtenSubtitle']))
story.append(Spacer(1, 0.5*cm))
# Meta information
meta_data = [
["Schuljahr:", f"{klausur_data.get('year', 2025)}"],
["Kurs:", klausur_data.get('semester', 'Abitur')],
["Anzahl Arbeiten:", str(len(students))],
["Stand:", datetime.now().strftime("%d.%m.%Y")]
]
meta_table = Table(meta_data, colWidths=[4*cm, 10*cm])
meta_table.setStyle(TableStyle([
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 10),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('TOPPADDING', (0, 0), (-1, -1), 4),
]))
story.append(meta_table)
story.append(Spacer(1, 0.5*cm))
# Statistics (if fairness data available)
if fairness_data and fairness_data.get('statistics'):
_add_statistics(story, styles, fairness_data['statistics'])
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.5*cm))
# Student grades table
sorted_students = sorted(students, key=lambda s: s.get('grade_points', 0), reverse=True)
_add_student_table(story, styles, sorted_students)
# Grade distribution
_add_grade_distribution(story, styles, sorted_students)
# Footer
story.append(Spacer(1, 1*cm))
story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
story.append(Spacer(1, 0.2*cm))
story.append(Paragraph(
f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
styles['MetaText']
))
# Build PDF
doc.build(story)
buffer.seek(0)
return buffer.getvalue()
def _add_statistics(story, styles, stats):
"""Add statistics section."""
story.append(Paragraph("Statistik", styles['SectionHeader']))
stats_data = [
["Durchschnitt:", f"{stats.get('average_grade', 0):.1f} Punkte"],
["Minimum:", f"{stats.get('min_grade', 0)} Punkte"],
["Maximum:", f"{stats.get('max_grade', 0)} Punkte"],
["Standardabweichung:", f"{stats.get('standard_deviation', 0):.2f}"],
]
stats_table = Table(stats_data, colWidths=[4*cm, 4*cm])
stats_table.setStyle(TableStyle([
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 9),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('BACKGROUND', (0, 0), (-1, -1), colors.HexColor('#f7fafc')),
('BOX', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
]))
story.append(stats_table)
story.append(Spacer(1, 0.5*cm))
def _add_student_table(story, styles, sorted_students):
"""Add student grades table."""
story.append(Paragraph("Einzelergebnisse", styles['SectionHeader']))
story.append(Spacer(1, 0.2*cm))
table_data = [["#", "Name", "Rohpunkte", "Notenpunkte", "Note", "Status"]]
for idx, student in enumerate(sorted_students, 1):
grade_points = student.get('grade_points', 0)
grade_note = GRADE_POINTS_TO_NOTE.get(grade_points, "-")
raw_points = student.get('raw_points', 0)
status = student.get('status', 'unknown')
status_display = {
'completed': 'Abgeschlossen',
'first_examiner': 'In Korrektur',
'second_examiner': 'Zweitkorrektur',
'uploaded': 'Hochgeladen',
'ocr_complete': 'OCR fertig',
'analyzing': 'Wird analysiert'
}.get(status, status)
table_data.append([
str(idx),
student.get('student_name', 'Anonym'),
f"{raw_points}/100",
str(grade_points),
grade_note,
status_display
])
student_table = Table(table_data, colWidths=[1*cm, 5*cm, 2.5*cm, 3*cm, 2*cm, 3*cm])
student_table.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, 0), 9),
('ALIGN', (0, 0), (-1, 0), 'CENTER'),
('FONTSIZE', (0, 1), (-1, -1), 9),
('ALIGN', (0, 1), (0, -1), 'CENTER'),
('ALIGN', (2, 1), (4, -1), 'CENTER'),
('BOTTOMPADDING', (0, 0), (-1, -1), 6),
('TOPPADDING', (0, 0), (-1, -1), 6),
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.HexColor('#f7fafc')]),
]))
story.append(student_table)
def _add_grade_distribution(story, styles, sorted_students):
"""Add grade distribution table."""
story.append(Spacer(1, 0.5*cm))
story.append(Paragraph("Notenverteilung", styles['SectionHeader']))
story.append(Spacer(1, 0.2*cm))
grade_counts = {}
for student in sorted_students:
gp = student.get('grade_points', 0)
grade_counts[gp] = grade_counts.get(gp, 0) + 1
dist_data = [["Punkte", "Note", "Anzahl"]]
for points in range(15, -1, -1):
if points in grade_counts:
note = GRADE_POINTS_TO_NOTE.get(points, "-")
count = grade_counts[points]
dist_data.append([str(points), note, str(count)])
if len(dist_data) > 1:
dist_table = Table(dist_data, colWidths=[2.5*cm, 2.5*cm, 2.5*cm])
dist_table.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 9),
('ALIGN', (0, 0), (-1, -1), 'CENTER'),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('TOPPADDING', (0, 0), (-1, -1), 4),
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
]))
story.append(dist_table)
def generate_annotations_pdf(
student_data: Dict[str, Any],
klausur_data: Dict[str, Any],
annotations: List[Dict[str, Any]]
) -> bytes:
"""
Generate a PDF with all annotations for a student work.
Args:
student_data: Student work data
klausur_data: Klausur metadata
annotations: List of all annotations
Returns:
PDF as bytes
"""
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
pagesize=A4,
rightMargin=2*cm,
leftMargin=2*cm,
topMargin=2*cm,
bottomMargin=2*cm
)
styles = get_custom_styles()
story = []
# Header
story.append(Paragraph("Anmerkungen zur Klausur", styles['GutachtenTitle']))
story.append(Paragraph(f"{student_data.get('student_name', 'Anonym')}", styles['GutachtenSubtitle']))
story.append(Spacer(1, 0.5*cm))
if not annotations:
story.append(Paragraph("<i>Keine Anmerkungen vorhanden.</i>", styles['GutachtenBody']))
else:
# Group by type
by_type = {}
for ann in annotations:
ann_type = ann.get('type', 'comment')
if ann_type not in by_type:
by_type[ann_type] = []
by_type[ann_type].append(ann)
for ann_type, anns in by_type.items():
type_name = CRITERIA_DISPLAY_NAMES.get(ann_type, ann_type.replace('_', ' ').title())
story.append(Paragraph(f"{type_name} ({len(anns)})", styles['SectionHeader']))
story.append(Spacer(1, 0.2*cm))
sorted_anns = sorted(anns, key=lambda a: (a.get('page', 0), a.get('position', {}).get('y', 0)))
for idx, ann in enumerate(sorted_anns, 1):
page = ann.get('page', 1)
text = ann.get('text', '')
suggestion = ann.get('suggestion', '')
severity = ann.get('severity', 'minor')
ann_text = f"<b>[S.{page}]</b> {text}"
if suggestion:
ann_text += f" -> <i>{suggestion}</i>"
if severity == 'critical':
ann_text = f"<font color='red'>{ann_text}</font>"
elif severity == 'major':
ann_text = f"<font color='orange'>{ann_text}</font>"
story.append(Paragraph(f"{idx}. {ann_text}", styles['ListItem']))
story.append(Spacer(1, 0.3*cm))
# Footer
story.append(Spacer(1, 1*cm))
story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
story.append(Spacer(1, 0.2*cm))
story.append(Paragraph(
f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
styles['MetaText']
))
# Build PDF
doc.build(story)
buffer.seek(0)
return buffer.getvalue()
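# Illustrative usage sketch (run via ``python -m`` inside the enclosing package
# because of the relative imports); all names, points and the single annotation
# are invented example values.
if __name__ == "__main__":
    klausur = {"subject": "Deutsch", "title": "Abitur 2025", "year": 2025}
    students = [
        {"student_name": "Anonym 01", "raw_points": 82, "grade_points": 12, "status": "completed"},
        {"student_name": "Anonym 02", "raw_points": 55, "grade_points": 7, "status": "first_examiner"},
    ]
    overview = generate_klausur_overview_pdf(klausur, students)
    notes = generate_annotations_pdf(
        students[0],
        klausur,
        annotations=[{"type": "grammatik", "page": 2, "text": "Kommafehler", "severity": "minor"}],
    )
    print(len(overview), len(notes))  # sizes of the two generated PDFs in bytes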
@@ -0,0 +1,110 @@
"""
PDF Export - Constants and ReportLab styles for Abiturkorrektur PDFs.
"""
from reportlab.lib import colors
from reportlab.lib.enums import TA_LEFT, TA_CENTER, TA_JUSTIFY
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
# =============================================
# CONSTANTS
# =============================================
GRADE_POINTS_TO_NOTE = {
15: "1+", 14: "1", 13: "1-",
12: "2+", 11: "2", 10: "2-",
9: "3+", 8: "3", 7: "3-",
6: "4+", 5: "4", 4: "4-",
3: "5+", 2: "5", 1: "5-",
0: "6"
}
CRITERIA_DISPLAY_NAMES = {
"rechtschreibung": "Sprachliche Richtigkeit (Rechtschreibung)",
"grammatik": "Sprachliche Richtigkeit (Grammatik)",
"inhalt": "Inhaltliche Leistung",
"struktur": "Aufbau und Struktur",
"stil": "Ausdruck und Stil"
}
CRITERIA_WEIGHTS = {
"rechtschreibung": 15,
"grammatik": 15,
"inhalt": 40,
"struktur": 15,
"stil": 15
}
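# Worked example with the weights above: criteria percentages of
# inhalt=80, struktur=70, stil=60, rechtschreibung=90, grammatik=75 give
#   0.80*40 + 0.70*15 + 0.60*15 + 0.90*15 + 0.75*15 = 76.25 raw points of 100
# (the same weighted sum that the Gutachten criteria table computes).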
# =============================================
# STYLES
# =============================================
def get_custom_styles():
"""Create custom paragraph styles for Gutachten."""
styles = getSampleStyleSheet()
# Title style
styles.add(ParagraphStyle(
name='GutachtenTitle',
parent=styles['Heading1'],
fontSize=16,
spaceAfter=12,
alignment=TA_CENTER,
textColor=colors.HexColor('#1e3a5f')
))
# Subtitle style
styles.add(ParagraphStyle(
name='GutachtenSubtitle',
parent=styles['Heading2'],
fontSize=12,
spaceAfter=8,
spaceBefore=16,
textColor=colors.HexColor('#2c5282')
))
# Section header
styles.add(ParagraphStyle(
name='SectionHeader',
parent=styles['Heading3'],
fontSize=11,
spaceAfter=6,
spaceBefore=12,
textColor=colors.HexColor('#2d3748'),
borderColor=colors.HexColor('#e2e8f0'),
borderWidth=0,
borderPadding=0
))
# Body text
styles.add(ParagraphStyle(
name='GutachtenBody',
parent=styles['Normal'],
fontSize=10,
leading=14,
alignment=TA_JUSTIFY,
spaceAfter=6
))
# Small text for footer/meta
styles.add(ParagraphStyle(
name='MetaText',
parent=styles['Normal'],
fontSize=8,
textColor=colors.grey,
alignment=TA_LEFT
))
# List item
styles.add(ParagraphStyle(
name='ListItem',
parent=styles['Normal'],
fontSize=10,
leftIndent=20,
bulletIndent=10,
spaceAfter=4
))
return styles
@@ -0,0 +1,164 @@
"""
PDF Extraction Module
NOTE: This module delegates ML-heavy operations to the embedding-service via HTTP.
Provides enhanced PDF text extraction using multiple backends (in embedding-service):
1. Unstructured.io - Best for complex layouts, tables, headers (Apache 2.0)
2. pypdf - Modern, BSD-licensed PDF library (recommended default)
License Compliance:
- Default backends (unstructured, pypdf) are BSD/Apache licensed
"""
import os
import logging
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
# Configuration (for backward compatibility - actual config in embedding-service)
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
PDF_BACKEND = os.getenv("PDF_EXTRACTION_BACKEND", "auto")
class PDFExtractionError(Exception):
"""Error during PDF extraction."""
pass
class PDFExtractionResult:
"""Result of PDF extraction with metadata."""
def __init__(
self,
text: str,
backend_used: str,
pages: int = 0,
elements: Optional[List[Dict]] = None,
tables: Optional[List[Dict]] = None,
metadata: Optional[Dict] = None,
):
self.text = text
self.backend_used = backend_used
self.pages = pages
self.elements = elements or []
self.tables = tables or []
self.metadata = metadata or {}
def to_dict(self) -> Dict:
return {
"text": self.text,
"backend_used": self.backend_used,
"pages": self.pages,
"element_count": len(self.elements),
"table_count": len(self.tables),
"metadata": self.metadata,
}
def _detect_available_backends() -> List[str]:
"""Get available backends from embedding-service."""
import httpx
try:
with httpx.Client(timeout=5.0) as client:
response = client.get(f"{EMBEDDING_SERVICE_URL}/models")
if response.status_code == 200:
data = response.json()
return data.get("available_pdf_backends", ["pypdf"])
except Exception as e:
logger.warning(f"Could not reach embedding-service: {e}")
return []
def extract_text_from_pdf_enhanced(
pdf_content: bytes,
backend: str = PDF_BACKEND,
fallback: bool = True,
) -> PDFExtractionResult:
"""
Extract text from PDF using embedding-service.
Args:
pdf_content: PDF file content as bytes
backend: Preferred backend (auto, unstructured, pypdf)
fallback: If True, try other backends if preferred fails
Returns:
PDFExtractionResult with extracted text and metadata
"""
import httpx
try:
with httpx.Client(timeout=120.0) as client:
response = client.post(
f"{EMBEDDING_SERVICE_URL}/extract-pdf",
content=pdf_content,
headers={"Content-Type": "application/octet-stream"}
)
response.raise_for_status()
data = response.json()
return PDFExtractionResult(
text=data.get("text", ""),
backend_used=data.get("backend_used", "unknown"),
pages=data.get("pages", 0),
tables=[{"count": data.get("table_count", 0)}] if data.get("table_count", 0) > 0 else [],
metadata={"embedding_service": True}
)
except httpx.TimeoutException:
raise PDFExtractionError("PDF extraction timeout")
except httpx.HTTPStatusError as e:
raise PDFExtractionError(f"PDF extraction error: {e.response.status_code}")
except Exception as e:
raise PDFExtractionError(f"Failed to extract PDF: {str(e)}")
def extract_text_from_pdf(pdf_content: bytes) -> str:
"""
Extract text from PDF (simple interface).
This is a drop-in replacement for the original function
that uses the embedding-service internally.
"""
result = extract_text_from_pdf_enhanced(pdf_content)
return result.text
def get_pdf_extraction_info() -> dict:
"""Get information about PDF extraction configuration."""
import httpx
try:
with httpx.Client(timeout=5.0) as client:
response = client.get(f"{EMBEDDING_SERVICE_URL}/models")
if response.status_code == 200:
data = response.json()
available = data.get("available_pdf_backends", [])
return {
"configured_backend": data.get("pdf_backend", PDF_BACKEND),
"available_backends": available,
"recommended": "unstructured" if "unstructured" in available else "pypdf",
"backend_licenses": {
"unstructured": "Apache-2.0",
"pypdf": "BSD-3-Clause",
},
"commercial_safe_backends": available,
"embedding_service_url": EMBEDDING_SERVICE_URL,
"embedding_service_available": True,
}
except Exception as e:
logger.warning(f"Could not reach embedding-service: {e}")
# Fallback when embedding-service is not available
return {
"configured_backend": PDF_BACKEND,
"available_backends": [],
"recommended": None,
"backend_licenses": {},
"commercial_safe_backends": [],
"embedding_service_url": EMBEDDING_SERVICE_URL,
"embedding_service_available": False,
}
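# Illustrative usage sketch; assumes the embedding-service is reachable under
# EMBEDDING_SERVICE_URL and "beispiel.pdf" stands in for any local PDF file.
if __name__ == "__main__":
    with open("beispiel.pdf", "rb") as fh:
        result = extract_text_from_pdf_enhanced(fh.read())
    print(result.backend_used, result.pages, len(result.text))
    print(get_pdf_extraction_info().get("available_backends"))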
@@ -0,0 +1,6 @@
"""
metrics package — PostgreSQL metrics database operations.
Backward-compatible re-exports: consumers can still use
``from metrics_db import ...`` etc. via the shim files in backend/.
"""
@@ -0,0 +1,36 @@
"""
PostgreSQL Metrics Database Service — Barrel Re-export
Split into:
- db_core.py — Pool, feedback, metrics, relevance
- db_schema.py — Table initialization (DDL)
- db_zeugnis.py — Zeugnis source/document/stats operations
All public names are re-exported here for backward compatibility.
"""
# Schema: table initialization
from .db_schema import init_metrics_tables # noqa: F401
# Core: pool, feedback, search logs, metrics, relevance
from .db_core import ( # noqa: F401
DATABASE_URL,
get_pool,
store_feedback,
log_search,
log_upload,
calculate_metrics,
get_recent_feedback,
get_upload_history,
store_relevance_judgment,
calculate_precision_recall,
)
# Zeugnis operations
from .db_zeugnis import ( # noqa: F401
get_zeugnis_sources,
upsert_zeugnis_source,
get_zeugnis_documents,
get_zeugnis_stats,
log_zeugnis_event,
)
@@ -0,0 +1,459 @@
"""
PostgreSQL Metrics Database - Core Operations
Connection pool, feedback storage, search logging, upload history,
metrics calculation, and relevance judgments.
Extracted from metrics_db.py to keep files under 500 LOC.
"""
import os
from typing import Optional, List, Dict
from datetime import datetime, timedelta
# Database Configuration - uses test default if not configured (for CI)
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://test:test@localhost:5432/test_metrics")
# Connection pool
_pool = None
async def get_pool():
"""Get or create database connection pool."""
global _pool
if _pool is None:
try:
import asyncpg
_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
except ImportError:
print("Warning: asyncpg not installed. Metrics storage disabled.")
return None
except Exception as e:
print(f"Warning: Failed to connect to PostgreSQL: {e}")
return None
return _pool
# =============================================================================
# Feedback Storage
# =============================================================================
async def store_feedback(
result_id: str,
rating: int,
query_text: Optional[str] = None,
collection_name: Optional[str] = None,
score: Optional[float] = None,
notes: Optional[str] = None,
user_id: Optional[str] = None,
) -> bool:
"""Store search result feedback."""
pool = await get_pool()
if pool is None:
return False
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO rag_search_feedback
(result_id, query_text, collection_name, score, rating, notes, user_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
""",
result_id, query_text, collection_name, score, rating, notes, user_id
)
return True
except Exception as e:
print(f"Failed to store feedback: {e}")
return False
async def log_search(
query_text: str,
collection_name: str,
result_count: int,
latency_ms: int,
top_score: Optional[float] = None,
filters: Optional[Dict] = None,
) -> bool:
"""Log a search for metrics tracking."""
pool = await get_pool()
if pool is None:
return False
try:
import json
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO rag_search_logs
(query_text, collection_name, result_count, latency_ms, top_score, filters)
VALUES ($1, $2, $3, $4, $5, $6)
""",
query_text, collection_name, result_count, latency_ms, top_score,
json.dumps(filters) if filters else None
)
return True
except Exception as e:
print(f"Failed to log search: {e}")
return False
async def log_upload(
filename: str,
collection_name: str,
year: int,
pdfs_extracted: int,
minio_path: Optional[str] = None,
uploaded_by: Optional[str] = None,
) -> bool:
"""Log an upload for history tracking."""
pool = await get_pool()
if pool is None:
return False
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO rag_upload_history
(filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by)
VALUES ($1, $2, $3, $4, $5, $6)
""",
filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by
)
return True
except Exception as e:
print(f"Failed to log upload: {e}")
return False
# =============================================================================
# Metrics Calculation
# =============================================================================
async def calculate_metrics(
collection_name: Optional[str] = None,
days: int = 7,
) -> Dict:
"""
Calculate RAG quality metrics from stored feedback.
Returns:
Dict with precision, recall, MRR, latency, etc.
"""
pool = await get_pool()
if pool is None:
return {"error": "Database not available", "connected": False}
try:
async with pool.acquire() as conn:
since = datetime.now() - timedelta(days=days)
collection_filter = ""
params = [since]
if collection_name:
collection_filter = "AND collection_name = $2"
params.append(collection_name)
total_feedback = await conn.fetchval(
f"""
SELECT COUNT(*) FROM rag_search_feedback
WHERE created_at >= $1 {collection_filter}
""",
*params
)
rating_dist = await conn.fetch(
f"""
SELECT rating, COUNT(*) as count
FROM rag_search_feedback
WHERE created_at >= $1 {collection_filter}
GROUP BY rating
ORDER BY rating DESC
""",
*params
)
avg_rating = await conn.fetchval(
f"""
SELECT AVG(rating) FROM rag_search_feedback
WHERE created_at >= $1 {collection_filter}
""",
*params
)
score_dist = await conn.fetch(
f"""
SELECT
CASE
WHEN score >= 0.9 THEN '0.9+'
WHEN score >= 0.7 THEN '0.7-0.9'
WHEN score >= 0.5 THEN '0.5-0.7'
ELSE '<0.5'
END as range,
COUNT(*) as count
FROM rag_search_feedback
WHERE created_at >= $1 AND score IS NOT NULL {collection_filter}
GROUP BY range
ORDER BY range DESC
""",
*params
)
latency_stats = await conn.fetchrow(
f"""
SELECT
AVG(latency_ms) as avg_latency,
COUNT(*) as total_searches,
AVG(result_count) as avg_results
FROM rag_search_logs
                WHERE created_at >= $1 {collection_filter}
""",
*params
)
precision_at_5 = await conn.fetchval(
f"""
SELECT
CASE WHEN COUNT(*) > 0
THEN CAST(SUM(CASE WHEN rating >= 4 THEN 1 ELSE 0 END) AS FLOAT) / COUNT(*)
ELSE 0 END
FROM rag_search_feedback
WHERE created_at >= $1 {collection_filter}
""",
*params
) or 0
            # Proxy: normalized average rating is reported in place of a true MRR.
            mrr = (avg_rating or 0) / 5.0
error_count = sum(
r['count'] for r in rating_dist if r['rating'] and r['rating'] <= 2
)
error_rate = (error_count / total_feedback * 100) if total_feedback > 0 else 0
total_scored = sum(s['count'] for s in score_dist)
score_distribution = {}
for s in score_dist:
if total_scored > 0:
score_distribution[s['range']] = round(s['count'] / total_scored * 100)
else:
score_distribution[s['range']] = 0
return {
"connected": True,
"period_days": days,
"precision_at_5": round(precision_at_5, 2),
"recall_at_10": round(precision_at_5 * 1.1, 2),
"mrr": round(mrr, 2),
"avg_latency_ms": round(latency_stats['avg_latency'] or 0),
"total_ratings": total_feedback,
"total_searches": latency_stats['total_searches'] or 0,
"error_rate": round(error_rate, 1),
"score_distribution": score_distribution,
"rating_distribution": {
str(r['rating']): r['count'] for r in rating_dist if r['rating']
},
}
except Exception as e:
print(f"Failed to calculate metrics: {e}")
return {"error": str(e), "connected": False}
async def get_recent_feedback(limit: int = 20) -> List[Dict]:
"""Get recent feedback entries."""
pool = await get_pool()
if pool is None:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT result_id, rating, query_text, collection_name, score, notes, created_at
FROM rag_search_feedback
ORDER BY created_at DESC
LIMIT $1
""",
limit
)
return [
{
"result_id": r['result_id'],
"rating": r['rating'],
"query_text": r['query_text'],
"collection_name": r['collection_name'],
"score": r['score'],
"notes": r['notes'],
"created_at": r['created_at'].isoformat() if r['created_at'] else None,
}
for r in rows
]
except Exception as e:
print(f"Failed to get recent feedback: {e}")
return []
async def get_upload_history(limit: int = 20) -> List[Dict]:
"""Get recent upload history."""
pool = await get_pool()
if pool is None:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by, created_at
FROM rag_upload_history
ORDER BY created_at DESC
LIMIT $1
""",
limit
)
return [
{
"filename": r['filename'],
"collection_name": r['collection_name'],
"year": r['year'],
"pdfs_extracted": r['pdfs_extracted'],
"minio_path": r['minio_path'],
"uploaded_by": r['uploaded_by'],
"created_at": r['created_at'].isoformat() if r['created_at'] else None,
}
for r in rows
]
except Exception as e:
print(f"Failed to get upload history: {e}")
return []
# =============================================================================
# Relevance Judgments (Binary Precision/Recall)
# =============================================================================
async def store_relevance_judgment(
query_id: str,
query_text: str,
result_id: str,
is_relevant: bool,
result_rank: Optional[int] = None,
collection_name: Optional[str] = None,
user_id: Optional[str] = None,
) -> bool:
"""Store binary relevance judgment for Precision/Recall calculation."""
pool = await get_pool()
if pool is None:
return False
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO rag_relevance_judgments
(query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT DO NOTHING
""",
query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id
)
return True
except Exception as e:
print(f"Failed to store relevance judgment: {e}")
return False
async def calculate_precision_recall(
collection_name: Optional[str] = None,
days: int = 7,
k: int = 10,
) -> Dict:
"""
Calculate true Precision@k and Recall@k from binary relevance judgments.
Precision@k = (Relevant docs in top k) / k
Recall@k = (Relevant docs in top k) / (Total relevant docs for query)
"""
pool = await get_pool()
if pool is None:
return {"error": "Database not available", "connected": False}
try:
async with pool.acquire() as conn:
since = datetime.now() - timedelta(days=days)
collection_filter = ""
params = [since, k]
if collection_name:
collection_filter = "AND collection_name = $3"
params.append(collection_name)
precision_result = await conn.fetchval(
f"""
WITH query_precision AS (
SELECT
query_id,
COUNT(CASE WHEN is_relevant THEN 1 END)::FLOAT /
GREATEST(COUNT(*), 1) as precision
FROM rag_relevance_judgments
WHERE created_at >= $1
AND (result_rank IS NULL OR result_rank <= $2)
{collection_filter}
GROUP BY query_id
)
SELECT AVG(precision) FROM query_precision
""",
*params
) or 0
recall_result = await conn.fetchval(
f"""
WITH query_recall AS (
SELECT
query_id,
COUNT(CASE WHEN is_relevant AND (result_rank IS NULL OR result_rank <= $2) THEN 1 END)::FLOAT /
GREATEST(COUNT(CASE WHEN is_relevant THEN 1 END), 1) as recall
FROM rag_relevance_judgments
WHERE created_at >= $1
{collection_filter}
GROUP BY query_id
)
SELECT AVG(recall) FROM query_recall
""",
*params
) or 0
            # collection_filter binds the collection to $3 (because $2 is k above);
            # the two count queries only pass the timestamp and the optional
            # collection, so they need a filter bound to $2 instead.
            count_filter = "AND collection_name = $2" if collection_name else ""
            count_params = [since] + ([collection_name] if collection_name else [])
            total_judgments = await conn.fetchval(
                f"""
                SELECT COUNT(*) FROM rag_relevance_judgments
                WHERE created_at >= $1 {count_filter}
                """,
                *count_params
            )
            unique_queries = await conn.fetchval(
                f"""
                SELECT COUNT(DISTINCT query_id) FROM rag_relevance_judgments
                WHERE created_at >= $1 {count_filter}
                """,
                *count_params
            )
return {
"connected": True,
"period_days": days,
"k": k,
"precision_at_k": round(precision_result, 3),
"recall_at_k": round(recall_result, 3),
"f1_score": round(
2 * precision_result * recall_result / max(precision_result + recall_result, 0.001), 3
),
"total_judgments": total_judgments or 0,
"unique_queries": unique_queries or 0,
}
except Exception as e:
print(f"Failed to calculate precision/recall: {e}")
return {"error": str(e), "connected": False}
@@ -0,0 +1,182 @@
"""
PostgreSQL Metrics Database - Schema Initialization
Table creation DDL for all metrics, feedback, and zeugnis tables.
Extracted from metrics_db_core.py to keep files under 500 LOC.
"""
from .db_core import get_pool
async def init_metrics_tables() -> bool:
"""Initialize metrics tables in PostgreSQL."""
pool = await get_pool()
if pool is None:
return False
create_tables_sql = """
-- RAG Search Feedback Table
CREATE TABLE IF NOT EXISTS rag_search_feedback (
id SERIAL PRIMARY KEY,
result_id VARCHAR(255) NOT NULL,
query_text TEXT,
collection_name VARCHAR(100),
score FLOAT,
rating INTEGER CHECK (rating >= 1 AND rating <= 5),
notes TEXT,
user_id VARCHAR(100),
created_at TIMESTAMP DEFAULT NOW()
);
-- Index for efficient querying
CREATE INDEX IF NOT EXISTS idx_feedback_created_at ON rag_search_feedback(created_at);
CREATE INDEX IF NOT EXISTS idx_feedback_collection ON rag_search_feedback(collection_name);
CREATE INDEX IF NOT EXISTS idx_feedback_rating ON rag_search_feedback(rating);
-- RAG Search Logs Table (for latency tracking)
CREATE TABLE IF NOT EXISTS rag_search_logs (
id SERIAL PRIMARY KEY,
query_text TEXT NOT NULL,
collection_name VARCHAR(100),
result_count INTEGER,
latency_ms INTEGER,
top_score FLOAT,
filters JSONB,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_search_logs_created_at ON rag_search_logs(created_at);
-- RAG Upload History Table
CREATE TABLE IF NOT EXISTS rag_upload_history (
id SERIAL PRIMARY KEY,
filename VARCHAR(500) NOT NULL,
collection_name VARCHAR(100),
year INTEGER,
pdfs_extracted INTEGER,
minio_path VARCHAR(1000),
uploaded_by VARCHAR(100),
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_upload_history_created_at ON rag_upload_history(created_at);
-- Binary relevance judgments for true Precision/Recall
CREATE TABLE IF NOT EXISTS rag_relevance_judgments (
id SERIAL PRIMARY KEY,
query_id VARCHAR(255) NOT NULL,
query_text TEXT NOT NULL,
result_id VARCHAR(255) NOT NULL,
result_rank INTEGER,
is_relevant BOOLEAN NOT NULL,
collection_name VARCHAR(100),
user_id VARCHAR(100),
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_relevance_query ON rag_relevance_judgments(query_id);
CREATE INDEX IF NOT EXISTS idx_relevance_created_at ON rag_relevance_judgments(created_at);
-- Zeugnisse Source Tracking
CREATE TABLE IF NOT EXISTS zeugnis_sources (
id VARCHAR(36) PRIMARY KEY,
bundesland VARCHAR(10) NOT NULL,
name VARCHAR(255) NOT NULL,
base_url TEXT,
license_type VARCHAR(50) NOT NULL,
training_allowed BOOLEAN DEFAULT FALSE,
verified_by VARCHAR(100),
verified_at TIMESTAMP,
created_at TIMESTAMP DEFAULT NOW(),
updated_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_sources_bundesland ON zeugnis_sources(bundesland);
-- Zeugnisse Seed URLs
CREATE TABLE IF NOT EXISTS zeugnis_seed_urls (
id VARCHAR(36) PRIMARY KEY,
source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
url TEXT NOT NULL,
doc_type VARCHAR(50),
status VARCHAR(20) DEFAULT 'pending',
last_crawled TIMESTAMP,
error_message TEXT,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_source ON zeugnis_seed_urls(source_id);
CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_status ON zeugnis_seed_urls(status);
-- Zeugnisse Documents
CREATE TABLE IF NOT EXISTS zeugnis_documents (
id VARCHAR(36) PRIMARY KEY,
seed_url_id VARCHAR(36) REFERENCES zeugnis_seed_urls(id),
title VARCHAR(500),
url TEXT NOT NULL,
content_hash VARCHAR(64),
minio_path TEXT,
training_allowed BOOLEAN DEFAULT FALSE,
indexed_in_qdrant BOOLEAN DEFAULT FALSE,
file_size INTEGER,
content_type VARCHAR(100),
created_at TIMESTAMP DEFAULT NOW(),
updated_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_seed ON zeugnis_documents(seed_url_id);
CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_hash ON zeugnis_documents(content_hash);
-- Zeugnisse Document Versions
CREATE TABLE IF NOT EXISTS zeugnis_document_versions (
id VARCHAR(36) PRIMARY KEY,
document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
version INTEGER NOT NULL,
content_hash VARCHAR(64),
minio_path TEXT,
change_summary TEXT,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_versions_doc ON zeugnis_document_versions(document_id);
-- Zeugnisse Usage Events (Audit Trail)
CREATE TABLE IF NOT EXISTS zeugnis_usage_events (
id VARCHAR(36) PRIMARY KEY,
document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
event_type VARCHAR(50) NOT NULL,
user_id VARCHAR(100),
details JSONB,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_doc ON zeugnis_usage_events(document_id);
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_type ON zeugnis_usage_events(event_type);
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_created ON zeugnis_usage_events(created_at);
-- Crawler Queue
CREATE TABLE IF NOT EXISTS zeugnis_crawler_queue (
id VARCHAR(36) PRIMARY KEY,
source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
priority INTEGER DEFAULT 5,
status VARCHAR(20) DEFAULT 'pending',
started_at TIMESTAMP,
completed_at TIMESTAMP,
documents_found INTEGER DEFAULT 0,
documents_indexed INTEGER DEFAULT 0,
error_count INTEGER DEFAULT 0,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_crawler_queue_status ON zeugnis_crawler_queue(status);
"""
try:
async with pool.acquire() as conn:
await conn.execute(create_tables_sql)
print("RAG metrics tables initialized")
return True
except Exception as e:
print(f"Failed to initialize metrics tables: {e}")
return False
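A sketch of how the DDL above might be wired into application startup; the lifespan hook and app object here are illustrative, not the repo's actual entry point (every statement uses CREATE ... IF NOT EXISTS, so repeated calls are safe):

# Hypothetical startup wiring -- the real entry point may differ.
from contextlib import asynccontextmanager
from fastapi import FastAPI
from metrics.db_schema import init_metrics_tables

@asynccontextmanager
async def lifespan(app: FastAPI):
    ok = await init_metrics_tables()  # idempotent; returns False if the DB is unreachable
    if not ok:
        print("Metrics tables not initialized; continuing without metrics storage")
    yield

app = FastAPI(lifespan=lifespan)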
@@ -0,0 +1,193 @@
"""
PostgreSQL Metrics Database - Zeugnis Operations
Zeugnis source management, document queries, statistics, and event logging.
Extracted from metrics_db.py to keep files under 500 LOC.
"""
from typing import Optional, List, Dict
from .db_core import get_pool
# =============================================================================
# Zeugnis Database Operations
# =============================================================================
async def get_zeugnis_sources() -> List[Dict]:
"""Get all zeugnis sources (Bundeslaender)."""
pool = await get_pool()
if pool is None:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT id, bundesland, name, base_url, license_type, training_allowed,
verified_by, verified_at, created_at, updated_at
FROM zeugnis_sources
ORDER BY bundesland
"""
)
return [dict(r) for r in rows]
except Exception as e:
print(f"Failed to get zeugnis sources: {e}")
return []
async def upsert_zeugnis_source(
id: str,
bundesland: str,
name: str,
license_type: str,
training_allowed: bool,
base_url: Optional[str] = None,
verified_by: Optional[str] = None,
) -> bool:
"""Insert or update a zeugnis source."""
pool = await get_pool()
if pool is None:
return False
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO zeugnis_sources (id, bundesland, name, base_url, license_type, training_allowed, verified_by, verified_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
ON CONFLICT (id) DO UPDATE SET
name = EXCLUDED.name,
base_url = EXCLUDED.base_url,
license_type = EXCLUDED.license_type,
training_allowed = EXCLUDED.training_allowed,
verified_by = EXCLUDED.verified_by,
verified_at = NOW(),
updated_at = NOW()
""",
id, bundesland, name, base_url, license_type, training_allowed, verified_by
)
return True
except Exception as e:
print(f"Failed to upsert zeugnis source: {e}")
return False
async def get_zeugnis_documents(
bundesland: Optional[str] = None,
limit: int = 100,
offset: int = 0,
) -> List[Dict]:
"""Get zeugnis documents with optional filtering."""
pool = await get_pool()
if pool is None:
return []
try:
async with pool.acquire() as conn:
if bundesland:
rows = await conn.fetch(
"""
SELECT d.*, s.bundesland, s.name as source_name
FROM zeugnis_documents d
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
JOIN zeugnis_sources s ON u.source_id = s.id
WHERE s.bundesland = $1
ORDER BY d.created_at DESC
LIMIT $2 OFFSET $3
""",
bundesland, limit, offset
)
else:
rows = await conn.fetch(
"""
SELECT d.*, s.bundesland, s.name as source_name
FROM zeugnis_documents d
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
JOIN zeugnis_sources s ON u.source_id = s.id
ORDER BY d.created_at DESC
LIMIT $1 OFFSET $2
""",
limit, offset
)
return [dict(r) for r in rows]
except Exception as e:
print(f"Failed to get zeugnis documents: {e}")
return []
async def get_zeugnis_stats() -> Dict:
"""Get zeugnis crawler statistics."""
pool = await get_pool()
if pool is None:
return {"error": "Database not available"}
try:
async with pool.acquire() as conn:
sources = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_sources")
documents = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_documents")
indexed = await conn.fetchval(
"SELECT COUNT(*) FROM zeugnis_documents WHERE indexed_in_qdrant = true"
)
training_allowed = await conn.fetchval(
"SELECT COUNT(*) FROM zeugnis_documents WHERE training_allowed = true"
)
per_bundesland = await conn.fetch(
"""
SELECT s.bundesland, s.name, s.training_allowed, COUNT(d.id) as doc_count
FROM zeugnis_sources s
LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id
LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id
GROUP BY s.bundesland, s.name, s.training_allowed
ORDER BY s.bundesland
"""
)
active_crawls = await conn.fetchval(
"SELECT COUNT(*) FROM zeugnis_crawler_queue WHERE status = 'running'"
)
return {
"total_sources": sources or 0,
"total_documents": documents or 0,
"indexed_documents": indexed or 0,
"training_allowed_documents": training_allowed or 0,
"active_crawls": active_crawls or 0,
"per_bundesland": [dict(r) for r in per_bundesland],
}
except Exception as e:
print(f"Failed to get zeugnis stats: {e}")
return {"error": str(e)}
async def log_zeugnis_event(
document_id: str,
event_type: str,
user_id: Optional[str] = None,
details: Optional[Dict] = None,
) -> bool:
"""Log a zeugnis usage event for audit trail."""
pool = await get_pool()
if pool is None:
return False
try:
import json
import uuid
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO zeugnis_usage_events (id, document_id, event_type, user_id, details)
VALUES ($1, $2, $3, $4, $5)
""",
str(uuid.uuid4()), document_id, event_type, user_id,
json.dumps(details) if details else None
)
return True
except Exception as e:
print(f"Failed to log zeugnis event: {e}")
return False
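A hedged usage sketch of the operations above; it assumes a reachable DATABASE_URL, and the document id below is a placeholder, not a real record:

# Illustrative usage; the helpers degrade gracefully (empty dict/False) without a DB.
import asyncio
from metrics.db_zeugnis import get_zeugnis_stats, log_zeugnis_event

async def main():
    stats = await get_zeugnis_stats()
    print(stats.get("total_documents", 0), "documents,",
          stats.get("indexed_documents", 0), "indexed in Qdrant")
    # Audit-trail entry for a (hypothetical) document id:
    await log_zeugnis_event("00000000-0000-0000-0000-000000000000",
                            event_type="export", user_id="admin",
                            details={"format": "pdf"})

asyncio.run(main())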
+4 -36
View File
@@ -1,36 +1,4 @@
"""
PostgreSQL Metrics Database Service — Barrel Re-export
Split into:
- metrics_db_core.py — Pool, feedback, metrics, relevance
- metrics_db_schema.py — Table initialization (DDL)
- metrics_db_zeugnis.py — Zeugnis source/document/stats operations
All public names are re-exported here for backward compatibility.
"""
# Schema: table initialization
from metrics_db_schema import init_metrics_tables # noqa: F401
# Core: pool, feedback, search logs, metrics, relevance
from metrics_db_core import ( # noqa: F401
DATABASE_URL,
get_pool,
store_feedback,
log_search,
log_upload,
calculate_metrics,
get_recent_feedback,
get_upload_history,
store_relevance_judgment,
calculate_precision_recall,
)
# Zeugnis operations
from metrics_db_zeugnis import ( # noqa: F401
get_zeugnis_sources,
upsert_zeugnis_source,
get_zeugnis_documents,
get_zeugnis_stats,
log_zeugnis_event,
)
# Backward-compat shim -- module moved to metrics/db.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("metrics.db")
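For reference, a quick check of what the shim does (a sketch; it assumes the backend directory is on sys.path and that metrics/db.py re-exports the same names as the old barrel):

# Sanity check of the shim mechanism.
import metrics_db            # legacy flat import -- now only executes the shim
import metrics.db as new_db  # relocated module

# The shim swapped itself out of sys.modules, so both names are the same object.
assert metrics_db is new_db
print("metrics_db resolves to", metrics_db.__name__)   # prints "metrics.db"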
+4 -459
View File
@@ -1,459 +1,4 @@
"""
PostgreSQL Metrics Database - Core Operations
Connection pool, table initialization, feedback storage, search logging,
upload history, metrics calculation, and relevance judgments.
Extracted from metrics_db.py to keep files under 500 LOC.
"""
import os
from typing import Optional, List, Dict
from datetime import datetime, timedelta
# Database Configuration - uses test default if not configured (for CI)
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://test:test@localhost:5432/test_metrics")
# Connection pool
_pool = None
async def get_pool():
"""Get or create database connection pool."""
global _pool
if _pool is None:
try:
import asyncpg
_pool = await asyncpg.create_pool(DATABASE_URL, min_size=2, max_size=10)
except ImportError:
print("Warning: asyncpg not installed. Metrics storage disabled.")
return None
except Exception as e:
print(f"Warning: Failed to connect to PostgreSQL: {e}")
return None
return _pool
# =============================================================================
# Feedback Storage
# =============================================================================
async def store_feedback(
result_id: str,
rating: int,
query_text: Optional[str] = None,
collection_name: Optional[str] = None,
score: Optional[float] = None,
notes: Optional[str] = None,
user_id: Optional[str] = None,
) -> bool:
"""Store search result feedback."""
pool = await get_pool()
if pool is None:
return False
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO rag_search_feedback
(result_id, query_text, collection_name, score, rating, notes, user_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
""",
result_id, query_text, collection_name, score, rating, notes, user_id
)
return True
except Exception as e:
print(f"Failed to store feedback: {e}")
return False
async def log_search(
query_text: str,
collection_name: str,
result_count: int,
latency_ms: int,
top_score: Optional[float] = None,
filters: Optional[Dict] = None,
) -> bool:
"""Log a search for metrics tracking."""
pool = await get_pool()
if pool is None:
return False
try:
import json
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO rag_search_logs
(query_text, collection_name, result_count, latency_ms, top_score, filters)
VALUES ($1, $2, $3, $4, $5, $6)
""",
query_text, collection_name, result_count, latency_ms, top_score,
json.dumps(filters) if filters else None
)
return True
except Exception as e:
print(f"Failed to log search: {e}")
return False
async def log_upload(
filename: str,
collection_name: str,
year: int,
pdfs_extracted: int,
minio_path: Optional[str] = None,
uploaded_by: Optional[str] = None,
) -> bool:
"""Log an upload for history tracking."""
pool = await get_pool()
if pool is None:
return False
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO rag_upload_history
(filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by)
VALUES ($1, $2, $3, $4, $5, $6)
""",
filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by
)
return True
except Exception as e:
print(f"Failed to log upload: {e}")
return False
# =============================================================================
# Metrics Calculation
# =============================================================================
async def calculate_metrics(
collection_name: Optional[str] = None,
days: int = 7,
) -> Dict:
"""
Calculate RAG quality metrics from stored feedback.
Returns:
Dict with precision, recall, MRR, latency, etc.
"""
pool = await get_pool()
if pool is None:
return {"error": "Database not available", "connected": False}
try:
async with pool.acquire() as conn:
since = datetime.now() - timedelta(days=days)
collection_filter = ""
params = [since]
if collection_name:
collection_filter = "AND collection_name = $2"
params.append(collection_name)
total_feedback = await conn.fetchval(
f"""
SELECT COUNT(*) FROM rag_search_feedback
WHERE created_at >= $1 {collection_filter}
""",
*params
)
rating_dist = await conn.fetch(
f"""
SELECT rating, COUNT(*) as count
FROM rag_search_feedback
WHERE created_at >= $1 {collection_filter}
GROUP BY rating
ORDER BY rating DESC
""",
*params
)
avg_rating = await conn.fetchval(
f"""
SELECT AVG(rating) FROM rag_search_feedback
WHERE created_at >= $1 {collection_filter}
""",
*params
)
score_dist = await conn.fetch(
f"""
SELECT
CASE
WHEN score >= 0.9 THEN '0.9+'
WHEN score >= 0.7 THEN '0.7-0.9'
WHEN score >= 0.5 THEN '0.5-0.7'
ELSE '<0.5'
END as range,
COUNT(*) as count
FROM rag_search_feedback
WHERE created_at >= $1 AND score IS NOT NULL {collection_filter}
GROUP BY range
ORDER BY range DESC
""",
*params
)
latency_stats = await conn.fetchrow(
f"""
SELECT
AVG(latency_ms) as avg_latency,
COUNT(*) as total_searches,
AVG(result_count) as avg_results
FROM rag_search_logs
WHERE created_at >= $1 {collection_filter}
""",
*params
)
precision_at_5 = await conn.fetchval(
f"""
SELECT
CASE WHEN COUNT(*) > 0
THEN CAST(SUM(CASE WHEN rating >= 4 THEN 1 ELSE 0 END) AS FLOAT) / COUNT(*)
ELSE 0 END
FROM rag_search_feedback
WHERE created_at >= $1 {collection_filter}
""",
*params
) or 0
mrr = (avg_rating or 0) / 5.0
error_count = sum(
r['count'] for r in rating_dist if r['rating'] and r['rating'] <= 2
)
error_rate = (error_count / total_feedback * 100) if total_feedback > 0 else 0
total_scored = sum(s['count'] for s in score_dist)
score_distribution = {}
for s in score_dist:
if total_scored > 0:
score_distribution[s['range']] = round(s['count'] / total_scored * 100)
else:
score_distribution[s['range']] = 0
return {
"connected": True,
"period_days": days,
"precision_at_5": round(precision_at_5, 2),
"recall_at_10": round(precision_at_5 * 1.1, 2),
"mrr": round(mrr, 2),
"avg_latency_ms": round(latency_stats['avg_latency'] or 0),
"total_ratings": total_feedback,
"total_searches": latency_stats['total_searches'] or 0,
"error_rate": round(error_rate, 1),
"score_distribution": score_distribution,
"rating_distribution": {
str(r['rating']): r['count'] for r in rating_dist if r['rating']
},
}
except Exception as e:
print(f"Failed to calculate metrics: {e}")
return {"error": str(e), "connected": False}
async def get_recent_feedback(limit: int = 20) -> List[Dict]:
"""Get recent feedback entries."""
pool = await get_pool()
if pool is None:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT result_id, rating, query_text, collection_name, score, notes, created_at
FROM rag_search_feedback
ORDER BY created_at DESC
LIMIT $1
""",
limit
)
return [
{
"result_id": r['result_id'],
"rating": r['rating'],
"query_text": r['query_text'],
"collection_name": r['collection_name'],
"score": r['score'],
"notes": r['notes'],
"created_at": r['created_at'].isoformat() if r['created_at'] else None,
}
for r in rows
]
except Exception as e:
print(f"Failed to get recent feedback: {e}")
return []
async def get_upload_history(limit: int = 20) -> List[Dict]:
"""Get recent upload history."""
pool = await get_pool()
if pool is None:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT filename, collection_name, year, pdfs_extracted, minio_path, uploaded_by, created_at
FROM rag_upload_history
ORDER BY created_at DESC
LIMIT $1
""",
limit
)
return [
{
"filename": r['filename'],
"collection_name": r['collection_name'],
"year": r['year'],
"pdfs_extracted": r['pdfs_extracted'],
"minio_path": r['minio_path'],
"uploaded_by": r['uploaded_by'],
"created_at": r['created_at'].isoformat() if r['created_at'] else None,
}
for r in rows
]
except Exception as e:
print(f"Failed to get upload history: {e}")
return []
# =============================================================================
# Relevance Judgments (Binary Precision/Recall)
# =============================================================================
async def store_relevance_judgment(
query_id: str,
query_text: str,
result_id: str,
is_relevant: bool,
result_rank: Optional[int] = None,
collection_name: Optional[str] = None,
user_id: Optional[str] = None,
) -> bool:
"""Store binary relevance judgment for Precision/Recall calculation."""
pool = await get_pool()
if pool is None:
return False
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO rag_relevance_judgments
(query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT DO NOTHING
""",
query_id, query_text, result_id, result_rank, is_relevant, collection_name, user_id
)
return True
except Exception as e:
print(f"Failed to store relevance judgment: {e}")
return False
async def calculate_precision_recall(
collection_name: Optional[str] = None,
days: int = 7,
k: int = 10,
) -> Dict:
"""
Calculate true Precision@k and Recall@k from binary relevance judgments.
Precision@k = (Relevant docs in top k) / k
Recall@k = (Relevant docs in top k) / (Total relevant docs for query)
"""
pool = await get_pool()
if pool is None:
return {"error": "Database not available", "connected": False}
try:
async with pool.acquire() as conn:
since = datetime.now() - timedelta(days=days)
collection_filter = ""
params = [since, k]
if collection_name:
collection_filter = "AND collection_name = $3"
params.append(collection_name)
precision_result = await conn.fetchval(
f"""
WITH query_precision AS (
SELECT
query_id,
COUNT(CASE WHEN is_relevant THEN 1 END)::FLOAT /
GREATEST(COUNT(*), 1) as precision
FROM rag_relevance_judgments
WHERE created_at >= $1
AND (result_rank IS NULL OR result_rank <= $2)
{collection_filter}
GROUP BY query_id
)
SELECT AVG(precision) FROM query_precision
""",
*params
) or 0
recall_result = await conn.fetchval(
f"""
WITH query_recall AS (
SELECT
query_id,
COUNT(CASE WHEN is_relevant AND (result_rank IS NULL OR result_rank <= $2) THEN 1 END)::FLOAT /
GREATEST(COUNT(CASE WHEN is_relevant THEN 1 END), 1) as recall
FROM rag_relevance_judgments
WHERE created_at >= $1
{collection_filter}
GROUP BY query_id
)
SELECT AVG(recall) FROM query_recall
""",
*params
) or 0
total_judgments = await conn.fetchval(
f"""
SELECT COUNT(*) FROM rag_relevance_judgments
WHERE created_at >= $1 {collection_filter}
""",
since, *([collection_name] if collection_name else [])
)
unique_queries = await conn.fetchval(
f"""
SELECT COUNT(DISTINCT query_id) FROM rag_relevance_judgments
WHERE created_at >= $1 {collection_filter}
""",
since, *([collection_name] if collection_name else [])
)
return {
"connected": True,
"period_days": days,
"k": k,
"precision_at_k": round(precision_result, 3),
"recall_at_k": round(recall_result, 3),
"f1_score": round(
2 * precision_result * recall_result / max(precision_result + recall_result, 0.001), 3
),
"total_judgments": total_judgments or 0,
"unique_queries": unique_queries or 0,
}
except Exception as e:
print(f"Failed to calculate precision/recall: {e}")
return {"error": str(e), "connected": False}
# Backward-compat shim -- module moved to metrics/db_core.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("metrics.db_core")
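A minimal sketch of the feedback loop the core module supports, written against the relocated import path; the collection name, result id, and query below are invented, and asyncpg plus a reachable DATABASE_URL are assumed:

# Illustrative only -- store one rating, then read the aggregated metrics back.
import asyncio
from metrics.db_core import store_feedback, calculate_metrics

async def main():
    await store_feedback(result_id="doc-123", rating=5,
                         query_text="Bewertungsraster Abitur Deutsch",
                         collection_name="nibis_eh", score=0.91)
    report = await calculate_metrics(collection_name="nibis_eh", days=7)
    print(report.get("precision_at_5"), report.get("avg_latency_ms"))

asyncio.run(main())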
+4 -182
View File
@@ -1,182 +1,4 @@
"""
PostgreSQL Metrics Database - Schema Initialization
Table creation DDL for all metrics, feedback, and zeugnis tables.
Extracted from metrics_db_core.py to keep files under 500 LOC.
"""
from metrics_db_core import get_pool
async def init_metrics_tables() -> bool:
"""Initialize metrics tables in PostgreSQL."""
pool = await get_pool()
if pool is None:
return False
create_tables_sql = """
-- RAG Search Feedback Table
CREATE TABLE IF NOT EXISTS rag_search_feedback (
id SERIAL PRIMARY KEY,
result_id VARCHAR(255) NOT NULL,
query_text TEXT,
collection_name VARCHAR(100),
score FLOAT,
rating INTEGER CHECK (rating >= 1 AND rating <= 5),
notes TEXT,
user_id VARCHAR(100),
created_at TIMESTAMP DEFAULT NOW()
);
-- Index for efficient querying
CREATE INDEX IF NOT EXISTS idx_feedback_created_at ON rag_search_feedback(created_at);
CREATE INDEX IF NOT EXISTS idx_feedback_collection ON rag_search_feedback(collection_name);
CREATE INDEX IF NOT EXISTS idx_feedback_rating ON rag_search_feedback(rating);
-- RAG Search Logs Table (for latency tracking)
CREATE TABLE IF NOT EXISTS rag_search_logs (
id SERIAL PRIMARY KEY,
query_text TEXT NOT NULL,
collection_name VARCHAR(100),
result_count INTEGER,
latency_ms INTEGER,
top_score FLOAT,
filters JSONB,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_search_logs_created_at ON rag_search_logs(created_at);
-- RAG Upload History Table
CREATE TABLE IF NOT EXISTS rag_upload_history (
id SERIAL PRIMARY KEY,
filename VARCHAR(500) NOT NULL,
collection_name VARCHAR(100),
year INTEGER,
pdfs_extracted INTEGER,
minio_path VARCHAR(1000),
uploaded_by VARCHAR(100),
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_upload_history_created_at ON rag_upload_history(created_at);
-- Binary relevance judgments for true Precision/Recall
CREATE TABLE IF NOT EXISTS rag_relevance_judgments (
id SERIAL PRIMARY KEY,
query_id VARCHAR(255) NOT NULL,
query_text TEXT NOT NULL,
result_id VARCHAR(255) NOT NULL,
result_rank INTEGER,
is_relevant BOOLEAN NOT NULL,
collection_name VARCHAR(100),
user_id VARCHAR(100),
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_relevance_query ON rag_relevance_judgments(query_id);
CREATE INDEX IF NOT EXISTS idx_relevance_created_at ON rag_relevance_judgments(created_at);
-- Zeugnisse Source Tracking
CREATE TABLE IF NOT EXISTS zeugnis_sources (
id VARCHAR(36) PRIMARY KEY,
bundesland VARCHAR(10) NOT NULL,
name VARCHAR(255) NOT NULL,
base_url TEXT,
license_type VARCHAR(50) NOT NULL,
training_allowed BOOLEAN DEFAULT FALSE,
verified_by VARCHAR(100),
verified_at TIMESTAMP,
created_at TIMESTAMP DEFAULT NOW(),
updated_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_sources_bundesland ON zeugnis_sources(bundesland);
-- Zeugnisse Seed URLs
CREATE TABLE IF NOT EXISTS zeugnis_seed_urls (
id VARCHAR(36) PRIMARY KEY,
source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
url TEXT NOT NULL,
doc_type VARCHAR(50),
status VARCHAR(20) DEFAULT 'pending',
last_crawled TIMESTAMP,
error_message TEXT,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_source ON zeugnis_seed_urls(source_id);
CREATE INDEX IF NOT EXISTS idx_zeugnis_seed_urls_status ON zeugnis_seed_urls(status);
-- Zeugnisse Documents
CREATE TABLE IF NOT EXISTS zeugnis_documents (
id VARCHAR(36) PRIMARY KEY,
seed_url_id VARCHAR(36) REFERENCES zeugnis_seed_urls(id),
title VARCHAR(500),
url TEXT NOT NULL,
content_hash VARCHAR(64),
minio_path TEXT,
training_allowed BOOLEAN DEFAULT FALSE,
indexed_in_qdrant BOOLEAN DEFAULT FALSE,
file_size INTEGER,
content_type VARCHAR(100),
created_at TIMESTAMP DEFAULT NOW(),
updated_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_seed ON zeugnis_documents(seed_url_id);
CREATE INDEX IF NOT EXISTS idx_zeugnis_documents_hash ON zeugnis_documents(content_hash);
-- Zeugnisse Document Versions
CREATE TABLE IF NOT EXISTS zeugnis_document_versions (
id VARCHAR(36) PRIMARY KEY,
document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
version INTEGER NOT NULL,
content_hash VARCHAR(64),
minio_path TEXT,
change_summary TEXT,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_versions_doc ON zeugnis_document_versions(document_id);
-- Zeugnisse Usage Events (Audit Trail)
CREATE TABLE IF NOT EXISTS zeugnis_usage_events (
id VARCHAR(36) PRIMARY KEY,
document_id VARCHAR(36) REFERENCES zeugnis_documents(id),
event_type VARCHAR(50) NOT NULL,
user_id VARCHAR(100),
details JSONB,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_doc ON zeugnis_usage_events(document_id);
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_type ON zeugnis_usage_events(event_type);
CREATE INDEX IF NOT EXISTS idx_zeugnis_events_created ON zeugnis_usage_events(created_at);
-- Crawler Queue
CREATE TABLE IF NOT EXISTS zeugnis_crawler_queue (
id VARCHAR(36) PRIMARY KEY,
source_id VARCHAR(36) REFERENCES zeugnis_sources(id),
priority INTEGER DEFAULT 5,
status VARCHAR(20) DEFAULT 'pending',
started_at TIMESTAMP,
completed_at TIMESTAMP,
documents_found INTEGER DEFAULT 0,
documents_indexed INTEGER DEFAULT 0,
error_count INTEGER DEFAULT 0,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_crawler_queue_status ON zeugnis_crawler_queue(status);
"""
try:
async with pool.acquire() as conn:
await conn.execute(create_tables_sql)
print("RAG metrics tables initialized")
return True
except Exception as e:
print(f"Failed to initialize metrics tables: {e}")
return False
# Backward-compat shim -- module moved to metrics/db_schema.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("metrics.db_schema")
+4 -193
View File
@@ -1,193 +1,4 @@
"""
PostgreSQL Metrics Database - Zeugnis Operations
Zeugnis source management, document queries, statistics, and event logging.
Extracted from metrics_db.py to keep files under 500 LOC.
"""
from typing import Optional, List, Dict
from metrics_db_core import get_pool
# =============================================================================
# Zeugnis Database Operations
# =============================================================================
async def get_zeugnis_sources() -> List[Dict]:
"""Get all zeugnis sources (Bundeslaender)."""
pool = await get_pool()
if pool is None:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT id, bundesland, name, base_url, license_type, training_allowed,
verified_by, verified_at, created_at, updated_at
FROM zeugnis_sources
ORDER BY bundesland
"""
)
return [dict(r) for r in rows]
except Exception as e:
print(f"Failed to get zeugnis sources: {e}")
return []
async def upsert_zeugnis_source(
id: str,
bundesland: str,
name: str,
license_type: str,
training_allowed: bool,
base_url: Optional[str] = None,
verified_by: Optional[str] = None,
) -> bool:
"""Insert or update a zeugnis source."""
pool = await get_pool()
if pool is None:
return False
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO zeugnis_sources (id, bundesland, name, base_url, license_type, training_allowed, verified_by, verified_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
ON CONFLICT (id) DO UPDATE SET
name = EXCLUDED.name,
base_url = EXCLUDED.base_url,
license_type = EXCLUDED.license_type,
training_allowed = EXCLUDED.training_allowed,
verified_by = EXCLUDED.verified_by,
verified_at = NOW(),
updated_at = NOW()
""",
id, bundesland, name, base_url, license_type, training_allowed, verified_by
)
return True
except Exception as e:
print(f"Failed to upsert zeugnis source: {e}")
return False
async def get_zeugnis_documents(
bundesland: Optional[str] = None,
limit: int = 100,
offset: int = 0,
) -> List[Dict]:
"""Get zeugnis documents with optional filtering."""
pool = await get_pool()
if pool is None:
return []
try:
async with pool.acquire() as conn:
if bundesland:
rows = await conn.fetch(
"""
SELECT d.*, s.bundesland, s.name as source_name
FROM zeugnis_documents d
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
JOIN zeugnis_sources s ON u.source_id = s.id
WHERE s.bundesland = $1
ORDER BY d.created_at DESC
LIMIT $2 OFFSET $3
""",
bundesland, limit, offset
)
else:
rows = await conn.fetch(
"""
SELECT d.*, s.bundesland, s.name as source_name
FROM zeugnis_documents d
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
JOIN zeugnis_sources s ON u.source_id = s.id
ORDER BY d.created_at DESC
LIMIT $1 OFFSET $2
""",
limit, offset
)
return [dict(r) for r in rows]
except Exception as e:
print(f"Failed to get zeugnis documents: {e}")
return []
async def get_zeugnis_stats() -> Dict:
"""Get zeugnis crawler statistics."""
pool = await get_pool()
if pool is None:
return {"error": "Database not available"}
try:
async with pool.acquire() as conn:
sources = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_sources")
documents = await conn.fetchval("SELECT COUNT(*) FROM zeugnis_documents")
indexed = await conn.fetchval(
"SELECT COUNT(*) FROM zeugnis_documents WHERE indexed_in_qdrant = true"
)
training_allowed = await conn.fetchval(
"SELECT COUNT(*) FROM zeugnis_documents WHERE training_allowed = true"
)
per_bundesland = await conn.fetch(
"""
SELECT s.bundesland, s.name, s.training_allowed, COUNT(d.id) as doc_count
FROM zeugnis_sources s
LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id
LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id
GROUP BY s.bundesland, s.name, s.training_allowed
ORDER BY s.bundesland
"""
)
active_crawls = await conn.fetchval(
"SELECT COUNT(*) FROM zeugnis_crawler_queue WHERE status = 'running'"
)
return {
"total_sources": sources or 0,
"total_documents": documents or 0,
"indexed_documents": indexed or 0,
"training_allowed_documents": training_allowed or 0,
"active_crawls": active_crawls or 0,
"per_bundesland": [dict(r) for r in per_bundesland],
}
except Exception as e:
print(f"Failed to get zeugnis stats: {e}")
return {"error": str(e)}
async def log_zeugnis_event(
document_id: str,
event_type: str,
user_id: Optional[str] = None,
details: Optional[Dict] = None,
) -> bool:
"""Log a zeugnis usage event for audit trail."""
pool = await get_pool()
if pool is None:
return False
try:
import json
import uuid
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO zeugnis_usage_events (id, document_id, event_type, user_id, details)
VALUES ($1, $2, $3, $4, $5)
""",
str(uuid.uuid4()), document_id, event_type, user_id,
json.dumps(details) if details else None
)
return True
except Exception as e:
print(f"Failed to log zeugnis event: {e}")
return False
# Backward-compat shim -- module moved to metrics/db_zeugnis.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("metrics.db_zeugnis")
@@ -1,26 +1,4 @@
"""
NRU Worksheet Generator — barrel re-export.
All implementation split into:
nru_worksheet_models — data classes, entry separation
nru_worksheet_html — HTML generation
nru_worksheet_pdf — PDF generation
Per scanned page, we generate 2 worksheet pages.
"""
# Models
from nru_worksheet_models import ( # noqa: F401
VocabEntry,
SentenceEntry,
separate_vocab_and_sentences,
)
# HTML generation
from nru_worksheet_html import ( # noqa: F401
generate_nru_html,
generate_nru_worksheet_html,
)
# PDF generation
from nru_worksheet_pdf import generate_nru_pdf # noqa: F401
# Backward-compat shim -- module moved to worksheet/nru_generator.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("worksheet.nru_generator")
+4 -466
View File
@@ -1,466 +1,4 @@
"""
NRU Worksheet HTML — HTML generation for vocabulary worksheets.
Extracted from nru_worksheet_generator.py for modularity.
"""
import logging
from typing import List, Dict
from nru_worksheet_models import VocabEntry, SentenceEntry, separate_vocab_and_sentences
logger = logging.getLogger(__name__)
def generate_nru_html(
vocab_list: List[VocabEntry],
sentence_list: List[SentenceEntry],
page_number: int,
title: str = "Vokabeltest",
show_solutions: bool = False,
line_height_px: int = 28
) -> str:
"""
Generate HTML for NRU-format worksheet.
Returns HTML for 2 pages:
- Page 1: Vocabulary table (3 columns)
- Page 2: Sentence practice (full width)
"""
# Filter by page
page_vocab = [v for v in vocab_list if v.source_page == page_number]
page_sentences = [s for s in sentence_list if s.source_page == page_number]
html = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
@page {{
size: A4;
margin: 1.5cm 2cm;
}}
* {{
box-sizing: border-box;
}}
body {{
font-family: Arial, Helvetica, sans-serif;
font-size: 12pt;
line-height: 1.4;
margin: 0;
padding: 0;
}}
.page {{
page-break-after: always;
min-height: 100%;
}}
.page:last-child {{
page-break-after: avoid;
}}
h1 {{
font-size: 16pt;
margin: 0 0 8px 0;
text-align: center;
}}
.header {{
margin-bottom: 15px;
}}
.name-line {{
font-size: 11pt;
margin-bottom: 10px;
}}
/* Vocabulary Table - 3 columns */
.vocab-table {{
width: 100%;
border-collapse: collapse;
table-layout: fixed;
}}
.vocab-table th {{
background: #f0f0f0;
border: 1px solid #333;
padding: 6px 8px;
font-weight: bold;
font-size: 11pt;
text-align: left;
}}
.vocab-table td {{
border: 1px solid #333;
padding: 4px 8px;
height: {line_height_px}px;
vertical-align: middle;
}}
.vocab-table .col-english {{ width: 35%; }}
.vocab-table .col-german {{ width: 35%; }}
.vocab-table .col-correction {{ width: 30%; }}
.vocab-answer {{
color: #0066cc;
font-style: italic;
}}
/* Sentence Table - full width */
.sentence-table {{
width: 100%;
border-collapse: collapse;
margin-bottom: 15px;
}}
.sentence-table td {{
border: 1px solid #333;
padding: 6px 10px;
}}
.sentence-header {{
background: #f5f5f5;
font-weight: normal;
min-height: 30px;
}}
.sentence-line {{
height: {line_height_px + 4}px;
}}
.sentence-answer {{
color: #0066cc;
font-style: italic;
font-size: 11pt;
}}
.page-info {{
font-size: 9pt;
color: #666;
text-align: right;
margin-top: 10px;
}}
</style>
</head>
<body>
"""
# ========== PAGE 1: VOCABULARY TABLE ==========
if page_vocab:
html += f"""
<div class="page">
<div class="header">
<h1>{title} - Vokabeln (Seite {page_number})</h1>
<div class="name-line">Name: _________________________ Datum: _____________</div>
</div>
<table class="vocab-table">
<thead>
<tr>
<th class="col-english">Englisch</th>
<th class="col-german">Deutsch</th>
<th class="col-correction">Korrektur</th>
</tr>
</thead>
<tbody>
"""
for v in page_vocab:
if show_solutions:
html += f"""
<tr>
<td>{v.english}</td>
<td class="vocab-answer">{v.german}</td>
<td></td>
</tr>
"""
else:
html += f"""
<tr>
<td>{v.english}</td>
<td></td>
<td></td>
</tr>
"""
html += """
</tbody>
</table>
<div class="page-info">Vokabeln aus Unit</div>
</div>
"""
# ========== PAGE 2: SENTENCE PRACTICE ==========
if page_sentences:
html += f"""
<div class="page">
<div class="header">
<h1>{title} - Lernsaetze (Seite {page_number})</h1>
<div class="name-line">Name: _________________________ Datum: _____________</div>
</div>
"""
for s in page_sentences:
html += f"""
<table class="sentence-table">
<tr>
<td class="sentence-header">{s.german}</td>
</tr>
"""
if show_solutions:
html += f"""
<tr>
<td class="sentence-line sentence-answer">{s.english}</td>
</tr>
<tr>
<td class="sentence-line"></td>
</tr>
"""
else:
html += """
<tr>
<td class="sentence-line"></td>
</tr>
<tr>
<td class="sentence-line"></td>
</tr>
"""
html += """
</table>
"""
html += """
<div class="page-info">Lernsaetze aus Unit</div>
</div>
"""
html += """
</body>
</html>
"""
return html
def generate_nru_worksheet_html(
entries: List[Dict],
title: str = "Vokabeltest",
show_solutions: bool = False,
specific_pages: List[int] = None
) -> str:
"""
Generate complete NRU worksheet HTML for all pages.
Args:
entries: List of vocabulary entries with source_page
title: Worksheet title
show_solutions: Whether to show answers
specific_pages: List of specific page numbers to include (1-indexed)
Returns:
Complete HTML document
"""
# Separate into vocab and sentences
vocab_list, sentence_list = separate_vocab_and_sentences(entries)
# Get unique page numbers
all_pages = set()
for v in vocab_list:
all_pages.add(v.source_page)
for s in sentence_list:
all_pages.add(s.source_page)
# Filter to specific pages if requested
if specific_pages:
all_pages = all_pages.intersection(set(specific_pages))
pages_sorted = sorted(all_pages)
logger.info(f"Generating NRU worksheet for pages {pages_sorted}")
logger.info(f"Total vocab: {len(vocab_list)}, Total sentences: {len(sentence_list)}")
# Generate HTML for each page
combined_html = """<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
@page {
size: A4;
margin: 1.5cm 2cm;
}
* {
box-sizing: border-box;
}
body {
font-family: Arial, Helvetica, sans-serif;
font-size: 12pt;
line-height: 1.4;
margin: 0;
padding: 0;
}
.page {
page-break-after: always;
min-height: 100%;
}
.page:last-child {
page-break-after: avoid;
}
h1 {
font-size: 16pt;
margin: 0 0 8px 0;
text-align: center;
}
.header {
margin-bottom: 15px;
}
.name-line {
font-size: 11pt;
margin-bottom: 10px;
}
/* Vocabulary Table - 3 columns */
.vocab-table {
width: 100%;
border-collapse: collapse;
table-layout: fixed;
}
.vocab-table th {
background: #f0f0f0;
border: 1px solid #333;
padding: 6px 8px;
font-weight: bold;
font-size: 11pt;
text-align: left;
}
.vocab-table td {
border: 1px solid #333;
padding: 4px 8px;
height: 28px;
vertical-align: middle;
}
.vocab-table .col-english { width: 35%; }
.vocab-table .col-german { width: 35%; }
.vocab-table .col-correction { width: 30%; }
.vocab-answer {
color: #0066cc;
font-style: italic;
}
/* Sentence Table - full width */
.sentence-table {
width: 100%;
border-collapse: collapse;
margin-bottom: 15px;
}
.sentence-table td {
border: 1px solid #333;
padding: 6px 10px;
}
.sentence-header {
background: #f5f5f5;
font-weight: normal;
min-height: 30px;
}
.sentence-line {
height: 32px;
}
.sentence-answer {
color: #0066cc;
font-style: italic;
font-size: 11pt;
}
.page-info {
font-size: 9pt;
color: #666;
text-align: right;
margin-top: 10px;
}
</style>
</head>
<body>
"""
for page_num in pages_sorted:
page_vocab = [v for v in vocab_list if v.source_page == page_num]
page_sentences = [s for s in sentence_list if s.source_page == page_num]
# PAGE 1: VOCABULARY TABLE
if page_vocab:
combined_html += f"""
<div class="page">
<div class="header">
<h1>{title} - Vokabeln (Seite {page_num})</h1>
<div class="name-line">Name: _________________________ Datum: _____________</div>
</div>
<table class="vocab-table">
<thead>
<tr>
<th class="col-english">Englisch</th>
<th class="col-german">Deutsch</th>
<th class="col-correction">Korrektur</th>
</tr>
</thead>
<tbody>
"""
for v in page_vocab:
if show_solutions:
combined_html += f"""
<tr>
<td>{v.english}</td>
<td class="vocab-answer">{v.german}</td>
<td></td>
</tr>
"""
else:
combined_html += f"""
<tr>
<td>{v.english}</td>
<td></td>
<td></td>
</tr>
"""
combined_html += f"""
</tbody>
</table>
<div class="page-info">{title} - Seite {page_num}</div>
</div>
"""
# PAGE 2: SENTENCE PRACTICE
if page_sentences:
combined_html += f"""
<div class="page">
<div class="header">
<h1>{title} - Lernsaetze (Seite {page_num})</h1>
<div class="name-line">Name: _________________________ Datum: _____________</div>
</div>
"""
for s in page_sentences:
combined_html += f"""
<table class="sentence-table">
<tr>
<td class="sentence-header">{s.german}</td>
</tr>
"""
if show_solutions:
combined_html += f"""
<tr>
<td class="sentence-line sentence-answer">{s.english}</td>
</tr>
<tr>
<td class="sentence-line"></td>
</tr>
"""
else:
combined_html += """
<tr>
<td class="sentence-line"></td>
</tr>
<tr>
<td class="sentence-line"></td>
</tr>
"""
combined_html += """
</table>
"""
combined_html += f"""
<div class="page-info">{title} - Seite {page_num}</div>
</div>
"""
combined_html += """
</body>
</html>
"""
return combined_html
# Backward-compat shim -- module moved to worksheet/nru_html.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("worksheet.nru_html")
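A short usage sketch of the HTML generator via its new location; the vocabulary entries are invented:

# The first entry is classified as vocabulary, the second (ends with ".") as a sentence.
from worksheet.nru_html import generate_nru_worksheet_html

entries = [
    {"english": "the bicycle", "german": "das Fahrrad", "source_page": 1},
    {"english": "He rides his bike to school every day.",
     "german": "Er faehrt jeden Tag mit dem Rad zur Schule.", "source_page": 1},
]
html = generate_nru_worksheet_html(entries, title="Vokabeltest Unit 1",
                                   show_solutions=False, specific_pages=[1])
with open("worksheet_page1.html", "w", encoding="utf-8") as fh:
    fh.write(html)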
@@ -1,70 +1,4 @@
"""
NRU Worksheet Models — data classes and entry separation logic.
Extracted from nru_worksheet_generator.py for modularity.
"""
import logging
from typing import List, Dict, Tuple
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class VocabEntry:
english: str
german: str
source_page: int = 1
@dataclass
class SentenceEntry:
german: str
english: str # For solution sheet
source_page: int = 1
def separate_vocab_and_sentences(entries: List[Dict]) -> Tuple[List[VocabEntry], List[SentenceEntry]]:
"""
Separate vocabulary entries into single words/phrases and full sentences.
Sentences are identified by:
- Ending with punctuation (. ! ?)
- Being longer than 50 characters
- Containing multiple words with capital letters mid-sentence
"""
vocab_list = []
sentence_list = []
for entry in entries:
english = entry.get("english", "").strip()
german = entry.get("german", "").strip()
source_page = entry.get("source_page", 1)
if not english or not german:
continue
# Detect if this is a sentence
is_sentence = (
english.endswith('.') or
english.endswith('!') or
english.endswith('?') or
len(english) > 50 or
(len(english.split()) > 5 and any(w[0].isupper() for w in english.split()[1:] if w))
)
if is_sentence:
sentence_list.append(SentenceEntry(
german=german,
english=english,
source_page=source_page
))
else:
vocab_list.append(VocabEntry(
english=english,
german=german,
source_page=source_page
))
return vocab_list, sentence_list
# Backward-compat shim -- module moved to worksheet/nru_models.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("worksheet.nru_models")
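A quick check of the separation heuristic via the relocated module; the entries are invented:

from worksheet.nru_models import separate_vocab_and_sentences

entries = [
    {"english": "curious", "german": "neugierig", "source_page": 2},
    {"english": "Where did you spend your holidays?",
     "german": "Wo hast du deine Ferien verbracht?", "source_page": 2},
]
vocab, sentences = separate_vocab_and_sentences(entries)
assert [v.english for v in vocab] == ["curious"]
assert [s.english for s in sentences] == ["Where did you spend your holidays?"]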
+4 -31
View File
@@ -1,31 +1,4 @@
"""
NRU Worksheet PDF — PDF generation using weasyprint.
Extracted from nru_worksheet_generator.py for modularity.
"""
from typing import List, Dict, Optional, Tuple
from nru_worksheet_html import generate_nru_worksheet_html
async def generate_nru_pdf(entries: List[Dict], title: str = "Vokabeltest", include_solutions: bool = True) -> Tuple[bytes, Optional[bytes]]:
"""
Generate NRU worksheet PDFs.
Returns:
Tuple of (worksheet_pdf_bytes, solution_pdf_bytes); the solution PDF is None when include_solutions is False.
"""
from weasyprint import HTML
# Generate worksheet HTML
worksheet_html = generate_nru_worksheet_html(entries, title, show_solutions=False)
worksheet_pdf = HTML(string=worksheet_html).write_pdf()
# Generate solution HTML
solution_pdf = None
if include_solutions:
solution_html = generate_nru_worksheet_html(entries, title, show_solutions=True)
solution_pdf = HTML(string=solution_html).write_pdf()
return worksheet_pdf, solution_pdf
# Backward-compat shim -- module moved to worksheet/nru_pdf.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("worksheet.nru_pdf")
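Usage sketch for the PDF wrapper via its new path; the entry data is invented and weasyprint must be installed:

import asyncio
from worksheet.nru_pdf import generate_nru_pdf

entries = [{"english": "the timetable", "german": "der Stundenplan", "source_page": 1}]

async def main():
    worksheet_pdf, solution_pdf = await generate_nru_pdf(
        entries, title="Vokabeltest Unit 2", include_solutions=True)
    open("worksheet.pdf", "wb").write(worksheet_pdf)
    if solution_pdf:
        open("solution.pdf", "wb").write(solution_pdf)

asyncio.run(main())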
+4 -17
View File
@@ -1,17 +1,4 @@
"""
PDF Export Module for Abiturkorrektur System
Barrel re-export: all PDF generation functions and constants.
"""
from pdf_export_styles import ( # noqa: F401
GRADE_POINTS_TO_NOTE,
CRITERIA_DISPLAY_NAMES,
CRITERIA_WEIGHTS,
get_custom_styles,
)
from pdf_export_gutachten import generate_gutachten_pdf # noqa: F401
from pdf_export_overview import ( # noqa: F401
generate_klausur_overview_pdf,
generate_annotations_pdf,
)
# Backward-compat shim -- module moved to korrektur/pdf_export.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("korrektur.pdf_export")
+4 -315
View File
@@ -1,315 +1,4 @@
"""
PDF Export - Individual Gutachten PDF generation.
Generates a single student's Gutachten with criteria table,
workflow info, and annotation summary.
"""
import io
from datetime import datetime
from typing import Dict, List, Optional, Any
from reportlab.lib import colors
from reportlab.lib.pagesizes import A4
from reportlab.lib.units import cm
from reportlab.platypus import (
SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle,
HRFlowable, KeepTogether
)
from pdf_export_styles import (
GRADE_POINTS_TO_NOTE,
CRITERIA_DISPLAY_NAMES,
CRITERIA_WEIGHTS,
get_custom_styles,
)
def generate_gutachten_pdf(
student_data: Dict[str, Any],
klausur_data: Dict[str, Any],
annotations: List[Dict[str, Any]] = None,
workflow_data: Dict[str, Any] = None
) -> bytes:
"""
Generate a PDF Gutachten for a single student.
Args:
student_data: Student work data including criteria_scores, gutachten, grade_points
klausur_data: Klausur metadata (title, subject, year, etc.)
annotations: List of annotations for annotation summary
workflow_data: Examiner workflow data (EK, ZK, DK info)
Returns:
PDF as bytes
"""
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
pagesize=A4,
rightMargin=2*cm,
leftMargin=2*cm,
topMargin=2*cm,
bottomMargin=2*cm
)
styles = get_custom_styles()
story = []
# Header
story.append(Paragraph("Gutachten zur Abiturklausur", styles['GutachtenTitle']))
story.append(Paragraph(f"{klausur_data.get('subject', 'Deutsch')} - {klausur_data.get('title', '')}", styles['GutachtenSubtitle']))
story.append(Spacer(1, 0.5*cm))
# Meta information table
meta_data = [
["Pruefling:", student_data.get('student_name', 'Anonym')],
["Schuljahr:", f"{klausur_data.get('year', 2025)}"],
["Kurs:", klausur_data.get('semester', 'Abitur')],
["Datum:", datetime.now().strftime("%d.%m.%Y")]
]
meta_table = Table(meta_data, colWidths=[4*cm, 10*cm])
meta_table.setStyle(TableStyle([
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 10),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('TOPPADDING', (0, 0), (-1, -1), 4),
]))
story.append(meta_table)
story.append(Spacer(1, 0.5*cm))
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.5*cm))
# Gutachten content
_add_gutachten_content(story, styles, student_data)
story.append(Spacer(1, 0.5*cm))
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.5*cm))
# Bewertungstabelle
_add_criteria_table(story, styles, student_data)
# Final grade box
_add_grade_box(story, styles, student_data)
# Examiner workflow information
if workflow_data:
_add_workflow_info(story, styles, workflow_data)
# Annotation summary
if annotations:
_add_annotation_summary(story, styles, annotations)
# Footer
_add_footer(story, styles)
# Build PDF
doc.build(story)
buffer.seek(0)
return buffer.getvalue()
def _add_gutachten_content(story, styles, student_data):
"""Add gutachten text sections to the story."""
gutachten = student_data.get('gutachten', {})
if gutachten:
if gutachten.get('einleitung'):
story.append(Paragraph("Einleitung", styles['SectionHeader']))
story.append(Paragraph(gutachten['einleitung'], styles['GutachtenBody']))
story.append(Spacer(1, 0.3*cm))
if gutachten.get('hauptteil'):
story.append(Paragraph("Hauptteil", styles['SectionHeader']))
story.append(Paragraph(gutachten['hauptteil'], styles['GutachtenBody']))
story.append(Spacer(1, 0.3*cm))
if gutachten.get('fazit'):
story.append(Paragraph("Fazit", styles['SectionHeader']))
story.append(Paragraph(gutachten['fazit'], styles['GutachtenBody']))
story.append(Spacer(1, 0.3*cm))
if gutachten.get('staerken') or gutachten.get('schwaechen'):
story.append(Spacer(1, 0.3*cm))
if gutachten.get('staerken'):
story.append(Paragraph("Staerken:", styles['SectionHeader']))
for s in gutachten['staerken']:
story.append(Paragraph(f"{s}", styles['ListItem']))
if gutachten.get('schwaechen'):
story.append(Paragraph("Verbesserungspotenzial:", styles['SectionHeader']))
for s in gutachten['schwaechen']:
story.append(Paragraph(f"{s}", styles['ListItem']))
else:
story.append(Paragraph("<i>Kein Gutachten-Text vorhanden.</i>", styles['GutachtenBody']))
def _add_criteria_table(story, styles, student_data):
"""Add criteria scoring table to the story."""
story.append(Paragraph("Bewertung nach Kriterien", styles['SectionHeader']))
story.append(Spacer(1, 0.2*cm))
criteria_scores = student_data.get('criteria_scores', {})
table_data = [["Kriterium", "Gewichtung", "Erreicht", "Punkte"]]
total_weighted = 0
total_weight = 0
for key, display_name in CRITERIA_DISPLAY_NAMES.items():
weight = CRITERIA_WEIGHTS.get(key, 0)
score_data = criteria_scores.get(key, {})
score = score_data.get('score', 0) if isinstance(score_data, dict) else score_data
weighted_score = (score / 100) * weight if score else 0
total_weighted += weighted_score
total_weight += weight
table_data.append([
display_name,
f"{weight}%",
f"{score}%",
f"{weighted_score:.1f}"
])
table_data.append([
"Gesamt",
f"{total_weight}%",
"",
f"{total_weighted:.1f}"
])
criteria_table = Table(table_data, colWidths=[8*cm, 2.5*cm, 2.5*cm, 2.5*cm])
criteria_table.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, 0), 10),
('ALIGN', (1, 0), (-1, -1), 'CENTER'),
('FONTSIZE', (0, 1), (-1, -1), 9),
('BOTTOMPADDING', (0, 0), (-1, -1), 6),
('TOPPADDING', (0, 0), (-1, -1), 6),
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
('BACKGROUND', (0, -1), (-1, -1), colors.HexColor('#f7fafc')),
('FONTNAME', (0, -1), (-1, -1), 'Helvetica-Bold'),
('ROWBACKGROUNDS', (0, 1), (-1, -2), [colors.white, colors.HexColor('#f7fafc')]),
]))
story.append(criteria_table)
story.append(Spacer(1, 0.5*cm))
def _add_grade_box(story, styles, student_data):
"""Add final grade box to the story."""
grade_points = student_data.get('grade_points', 0)
grade_note = GRADE_POINTS_TO_NOTE.get(grade_points, "?")
raw_points = student_data.get('raw_points', 0)
grade_data = [
["Rohpunkte:", f"{raw_points} / 100"],
["Notenpunkte:", f"{grade_points} Punkte"],
["Note:", grade_note]
]
grade_table = Table(grade_data, colWidths=[4*cm, 4*cm])
grade_table.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, -1), colors.HexColor('#ebf8ff')),
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTNAME', (1, -1), (1, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 11),
('FONTSIZE', (1, -1), (1, -1), 14),
('TEXTCOLOR', (1, -1), (1, -1), colors.HexColor('#2c5282')),
('BOTTOMPADDING', (0, 0), (-1, -1), 8),
('TOPPADDING', (0, 0), (-1, -1), 8),
('LEFTPADDING', (0, 0), (-1, -1), 12),
('BOX', (0, 0), (-1, -1), 1, colors.HexColor('#2c5282')),
('ALIGN', (1, 0), (1, -1), 'RIGHT'),
]))
story.append(KeepTogether([
Paragraph("Endergebnis", styles['SectionHeader']),
Spacer(1, 0.2*cm),
grade_table
]))
def _add_workflow_info(story, styles, workflow_data):
"""Add examiner workflow information to the story."""
story.append(Spacer(1, 0.5*cm))
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.3*cm))
story.append(Paragraph("Korrekturverlauf", styles['SectionHeader']))
workflow_rows = []
if workflow_data.get('erst_korrektor'):
ek = workflow_data['erst_korrektor']
workflow_rows.append([
"Erstkorrektor:",
ek.get('name', 'Unbekannt'),
f"{ek.get('grade_points', '-')} Punkte"
])
if workflow_data.get('zweit_korrektor'):
zk = workflow_data['zweit_korrektor']
workflow_rows.append([
"Zweitkorrektor:",
zk.get('name', 'Unbekannt'),
f"{zk.get('grade_points', '-')} Punkte"
])
if workflow_data.get('dritt_korrektor'):
dk = workflow_data['dritt_korrektor']
workflow_rows.append([
"Drittkorrektor:",
dk.get('name', 'Unbekannt'),
f"{dk.get('grade_points', '-')} Punkte"
])
if workflow_data.get('final_grade_source'):
workflow_rows.append([
"Endnote durch:",
workflow_data['final_grade_source'],
""
])
if workflow_rows:
workflow_table = Table(workflow_rows, colWidths=[4*cm, 6*cm, 4*cm])
workflow_table.setStyle(TableStyle([
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 9),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('TOPPADDING', (0, 0), (-1, -1), 4),
]))
story.append(workflow_table)
def _add_annotation_summary(story, styles, annotations):
"""Add annotation summary to the story."""
story.append(Spacer(1, 0.5*cm))
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.3*cm))
story.append(Paragraph("Anmerkungen (Zusammenfassung)", styles['SectionHeader']))
by_type = {}
for ann in annotations:
ann_type = ann.get('type', 'comment')
if ann_type not in by_type:
by_type[ann_type] = []
by_type[ann_type].append(ann)
for ann_type, anns in by_type.items():
type_name = CRITERIA_DISPLAY_NAMES.get(ann_type, ann_type.replace('_', ' ').title())
story.append(Paragraph(f"{type_name} ({len(anns)} Anmerkungen)", styles['ListItem']))
def _add_footer(story, styles):
"""Add generation footer to the story."""
story.append(Spacer(1, 1*cm))
story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
story.append(Spacer(1, 0.2*cm))
story.append(Paragraph(
f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
styles['MetaText']
))
# Backward-compat shim -- module moved to korrektur/pdf_export_gutachten.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("korrektur.pdf_export_gutachten")
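A minimal call sketch via the korrektur barrel; all student and Klausur data below is invented, and reportlab must be installed:

from korrektur.pdf_export import generate_gutachten_pdf

student = {
    "student_name": "Anonym 07",
    "raw_points": 78,
    "grade_points": 11,
    "criteria_scores": {},  # empty -> the criteria table falls back to 0% rows
    "gutachten": {"einleitung": "Die Arbeit erfasst die Aufgabenstellung sicher.",
                  "staerken": ["klarer Aufbau"], "schwaechen": ["Zitierweise"]},
}
klausur = {"title": "Lyrik des Expressionismus", "subject": "Deutsch",
           "year": 2025, "semester": "Q2"}

pdf_bytes = generate_gutachten_pdf(student, klausur)  # annotations/workflow omitted
open("gutachten_anonym07.pdf", "wb").write(pdf_bytes)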
+4 -297
View File
@@ -1,297 +1,4 @@
"""
PDF Export - Klausur overview and annotations PDF generation.
Generates:
- Klausur overview with grade distribution for all students
- Annotations PDF for a single student
"""
import io
from datetime import datetime
from typing import Dict, List, Optional, Any
from reportlab.lib import colors
from reportlab.lib.pagesizes import A4
from reportlab.lib.units import cm
from reportlab.platypus import (
SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle,
HRFlowable
)
from pdf_export_styles import (
GRADE_POINTS_TO_NOTE,
CRITERIA_DISPLAY_NAMES,
get_custom_styles,
)
def generate_klausur_overview_pdf(
klausur_data: Dict[str, Any],
students: List[Dict[str, Any]],
fairness_data: Optional[Dict[str, Any]] = None
) -> bytes:
"""
Generate an overview PDF for an entire Klausur with all student grades.
Args:
klausur_data: Klausur metadata
students: List of all student work data
fairness_data: Optional fairness analysis data
Returns:
PDF as bytes
"""
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
pagesize=A4,
rightMargin=1.5*cm,
leftMargin=1.5*cm,
topMargin=2*cm,
bottomMargin=2*cm
)
styles = get_custom_styles()
story = []
# Header
story.append(Paragraph("Notenuebersicht", styles['GutachtenTitle']))
story.append(Paragraph(f"{klausur_data.get('subject', 'Deutsch')} - {klausur_data.get('title', '')}", styles['GutachtenSubtitle']))
story.append(Spacer(1, 0.5*cm))
# Meta information
meta_data = [
["Schuljahr:", f"{klausur_data.get('year', 2025)}"],
["Kurs:", klausur_data.get('semester', 'Abitur')],
["Anzahl Arbeiten:", str(len(students))],
["Stand:", datetime.now().strftime("%d.%m.%Y")]
]
meta_table = Table(meta_data, colWidths=[4*cm, 10*cm])
meta_table.setStyle(TableStyle([
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 10),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('TOPPADDING', (0, 0), (-1, -1), 4),
]))
story.append(meta_table)
story.append(Spacer(1, 0.5*cm))
# Statistics (if fairness data available)
if fairness_data and fairness_data.get('statistics'):
_add_statistics(story, styles, fairness_data['statistics'])
story.append(HRFlowable(width="100%", thickness=1, color=colors.HexColor('#e2e8f0')))
story.append(Spacer(1, 0.5*cm))
# Student grades table
sorted_students = sorted(students, key=lambda s: s.get('grade_points', 0), reverse=True)
_add_student_table(story, styles, sorted_students)
# Grade distribution
_add_grade_distribution(story, styles, sorted_students)
# Footer
story.append(Spacer(1, 1*cm))
story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
story.append(Spacer(1, 0.2*cm))
story.append(Paragraph(
f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
styles['MetaText']
))
# Build PDF
doc.build(story)
buffer.seek(0)
return buffer.getvalue()
def _add_statistics(story, styles, stats):
"""Add statistics section."""
story.append(Paragraph("Statistik", styles['SectionHeader']))
stats_data = [
["Durchschnitt:", f"{stats.get('average_grade', 0):.1f} Punkte"],
["Minimum:", f"{stats.get('min_grade', 0)} Punkte"],
["Maximum:", f"{stats.get('max_grade', 0)} Punkte"],
["Standardabweichung:", f"{stats.get('standard_deviation', 0):.2f}"],
]
stats_table = Table(stats_data, colWidths=[4*cm, 4*cm])
stats_table.setStyle(TableStyle([
('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 9),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('BACKGROUND', (0, 0), (-1, -1), colors.HexColor('#f7fafc')),
('BOX', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
]))
story.append(stats_table)
story.append(Spacer(1, 0.5*cm))
def _add_student_table(story, styles, sorted_students):
"""Add student grades table."""
story.append(Paragraph("Einzelergebnisse", styles['SectionHeader']))
story.append(Spacer(1, 0.2*cm))
table_data = [["#", "Name", "Rohpunkte", "Notenpunkte", "Note", "Status"]]
for idx, student in enumerate(sorted_students, 1):
grade_points = student.get('grade_points', 0)
grade_note = GRADE_POINTS_TO_NOTE.get(grade_points, "-")
raw_points = student.get('raw_points', 0)
status = student.get('status', 'unknown')
status_display = {
'completed': 'Abgeschlossen',
'first_examiner': 'In Korrektur',
'second_examiner': 'Zweitkorrektur',
'uploaded': 'Hochgeladen',
'ocr_complete': 'OCR fertig',
'analyzing': 'Wird analysiert'
}.get(status, status)
table_data.append([
str(idx),
student.get('student_name', 'Anonym'),
f"{raw_points}/100",
str(grade_points),
grade_note,
status_display
])
student_table = Table(table_data, colWidths=[1*cm, 5*cm, 2.5*cm, 3*cm, 2*cm, 3*cm])
student_table.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, 0), 9),
('ALIGN', (0, 0), (-1, 0), 'CENTER'),
('FONTSIZE', (0, 1), (-1, -1), 9),
('ALIGN', (0, 1), (0, -1), 'CENTER'),
('ALIGN', (2, 1), (4, -1), 'CENTER'),
('BOTTOMPADDING', (0, 0), (-1, -1), 6),
('TOPPADDING', (0, 0), (-1, -1), 6),
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
('ROWBACKGROUNDS', (0, 1), (-1, -1), [colors.white, colors.HexColor('#f7fafc')]),
]))
story.append(student_table)
def _add_grade_distribution(story, styles, sorted_students):
"""Add grade distribution table."""
story.append(Spacer(1, 0.5*cm))
story.append(Paragraph("Notenverteilung", styles['SectionHeader']))
story.append(Spacer(1, 0.2*cm))
grade_counts = {}
for student in sorted_students:
gp = student.get('grade_points', 0)
grade_counts[gp] = grade_counts.get(gp, 0) + 1
dist_data = [["Punkte", "Note", "Anzahl"]]
for points in range(15, -1, -1):
if points in grade_counts:
note = GRADE_POINTS_TO_NOTE.get(points, "-")
count = grade_counts[points]
dist_data.append([str(points), note, str(count)])
if len(dist_data) > 1:
dist_table = Table(dist_data, colWidths=[2.5*cm, 2.5*cm, 2.5*cm])
dist_table.setStyle(TableStyle([
('BACKGROUND', (0, 0), (-1, 0), colors.HexColor('#2c5282')),
('TEXTCOLOR', (0, 0), (-1, 0), colors.white),
('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
('FONTSIZE', (0, 0), (-1, -1), 9),
('ALIGN', (0, 0), (-1, -1), 'CENTER'),
('BOTTOMPADDING', (0, 0), (-1, -1), 4),
('TOPPADDING', (0, 0), (-1, -1), 4),
('GRID', (0, 0), (-1, -1), 0.5, colors.HexColor('#e2e8f0')),
]))
story.append(dist_table)
def generate_annotations_pdf(
student_data: Dict[str, Any],
klausur_data: Dict[str, Any],
annotations: List[Dict[str, Any]]
) -> bytes:
"""
Generate a PDF with all annotations for a student work.
Args:
student_data: Student work data
klausur_data: Klausur metadata
annotations: List of all annotations
Returns:
PDF as bytes
"""
buffer = io.BytesIO()
doc = SimpleDocTemplate(
buffer,
pagesize=A4,
rightMargin=2*cm,
leftMargin=2*cm,
topMargin=2*cm,
bottomMargin=2*cm
)
styles = get_custom_styles()
story = []
# Header
story.append(Paragraph("Anmerkungen zur Klausur", styles['GutachtenTitle']))
story.append(Paragraph(f"{student_data.get('student_name', 'Anonym')}", styles['GutachtenSubtitle']))
story.append(Spacer(1, 0.5*cm))
if not annotations:
story.append(Paragraph("<i>Keine Anmerkungen vorhanden.</i>", styles['GutachtenBody']))
else:
# Group by type
by_type = {}
for ann in annotations:
ann_type = ann.get('type', 'comment')
if ann_type not in by_type:
by_type[ann_type] = []
by_type[ann_type].append(ann)
for ann_type, anns in by_type.items():
type_name = CRITERIA_DISPLAY_NAMES.get(ann_type, ann_type.replace('_', ' ').title())
story.append(Paragraph(f"{type_name} ({len(anns)})", styles['SectionHeader']))
story.append(Spacer(1, 0.2*cm))
sorted_anns = sorted(anns, key=lambda a: (a.get('page', 0), a.get('position', {}).get('y', 0)))
for idx, ann in enumerate(sorted_anns, 1):
page = ann.get('page', 1)
text = ann.get('text', '')
suggestion = ann.get('suggestion', '')
severity = ann.get('severity', 'minor')
ann_text = f"<b>[S.{page}]</b> {text}"
if suggestion:
ann_text += f" -> <i>{suggestion}</i>"
if severity == 'critical':
ann_text = f"<font color='red'>{ann_text}</font>"
elif severity == 'major':
ann_text = f"<font color='orange'>{ann_text}</font>"
story.append(Paragraph(f"{idx}. {ann_text}", styles['ListItem']))
story.append(Spacer(1, 0.3*cm))
# Footer
story.append(Spacer(1, 1*cm))
story.append(HRFlowable(width="100%", thickness=0.5, color=colors.HexColor('#cbd5e0')))
story.append(Spacer(1, 0.2*cm))
story.append(Paragraph(
f"Erstellt am {datetime.now().strftime('%d.%m.%Y um %H:%M Uhr')} | BreakPilot Abiturkorrektur-System",
styles['MetaText']
))
# Build PDF
doc.build(story)
buffer.seek(0)
return buffer.getvalue()
# Backward-compat shim -- module moved to korrektur/pdf_export_overview.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("korrektur.pdf_export_overview")
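For orientation, a minimal usage sketch for the relocated overview module (hypothetical data; assumes backend/ with its shims, or the new korrektur package, is on the import path):

from pdf_export_overview import generate_klausur_overview_pdf  # resolves to korrektur.pdf_export_overview via the shim

# Hypothetical example data; keys follow the .get() calls in the module above.
klausur = {"subject": "Deutsch", "title": "Abitur LK", "year": 2026, "semester": "Q2"}
students = [
    {"student_name": "A. Muster", "raw_points": 81, "grade_points": 12, "status": "completed"},
    {"student_name": "B. Beispiel", "raw_points": 55, "grade_points": 7, "status": "second_examiner"},
]

pdf_bytes = generate_klausur_overview_pdf(klausur, students)  # fairness_data is optional
with open("notenuebersicht.pdf", "wb") as fh:
    fh.write(pdf_bytes)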
+4 -110
@@ -1,110 +1,4 @@
"""
PDF Export - Constants and ReportLab styles for Abiturkorrektur PDFs.
"""
from reportlab.lib import colors
from reportlab.lib.enums import TA_LEFT, TA_CENTER, TA_JUSTIFY
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
# =============================================
# CONSTANTS
# =============================================
GRADE_POINTS_TO_NOTE = {
15: "1+", 14: "1", 13: "1-",
12: "2+", 11: "2", 10: "2-",
9: "3+", 8: "3", 7: "3-",
6: "4+", 5: "4", 4: "4-",
3: "5+", 2: "5", 1: "5-",
0: "6"
}
CRITERIA_DISPLAY_NAMES = {
"rechtschreibung": "Sprachliche Richtigkeit (Rechtschreibung)",
"grammatik": "Sprachliche Richtigkeit (Grammatik)",
"inhalt": "Inhaltliche Leistung",
"struktur": "Aufbau und Struktur",
"stil": "Ausdruck und Stil"
}
CRITERIA_WEIGHTS = {
"rechtschreibung": 15,
"grammatik": 15,
"inhalt": 40,
"struktur": 15,
"stil": 15
}
# =============================================
# STYLES
# =============================================
def get_custom_styles():
"""Create custom paragraph styles for Gutachten."""
styles = getSampleStyleSheet()
# Title style
styles.add(ParagraphStyle(
name='GutachtenTitle',
parent=styles['Heading1'],
fontSize=16,
spaceAfter=12,
alignment=TA_CENTER,
textColor=colors.HexColor('#1e3a5f')
))
# Subtitle style
styles.add(ParagraphStyle(
name='GutachtenSubtitle',
parent=styles['Heading2'],
fontSize=12,
spaceAfter=8,
spaceBefore=16,
textColor=colors.HexColor('#2c5282')
))
# Section header
styles.add(ParagraphStyle(
name='SectionHeader',
parent=styles['Heading3'],
fontSize=11,
spaceAfter=6,
spaceBefore=12,
textColor=colors.HexColor('#2d3748'),
borderColor=colors.HexColor('#e2e8f0'),
borderWidth=0,
borderPadding=0
))
# Body text
styles.add(ParagraphStyle(
name='GutachtenBody',
parent=styles['Normal'],
fontSize=10,
leading=14,
alignment=TA_JUSTIFY,
spaceAfter=6
))
# Small text for footer/meta
styles.add(ParagraphStyle(
name='MetaText',
parent=styles['Normal'],
fontSize=8,
textColor=colors.grey,
alignment=TA_LEFT
))
# List item
styles.add(ParagraphStyle(
name='ListItem',
parent=styles['Normal'],
fontSize=10,
leftIndent=20,
bulletIndent=10,
spaceAfter=4
))
return styles
# Backward-compat shim -- module moved to korrektur/pdf_export_styles.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("korrektur.pdf_export_styles")
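The constants and styles are consumed directly by the three pdf_export_* modules; a small sketch (hypothetical values):

from pdf_export_styles import GRADE_POINTS_TO_NOTE, CRITERIA_WEIGHTS, get_custom_styles

styles = get_custom_styles()
note = GRADE_POINTS_TO_NOTE.get(11, "-")       # -> "2"
total_weight = sum(CRITERIA_WEIGHTS.values())  # -> 100
header = styles['SectionHeader']               # ParagraphStyle registered by get_custom_styles()
print(note, total_weight, header.fontSize)     # -> 2 100 11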
+4 -164
View File
@@ -1,164 +1,4 @@
"""
PDF Extraction Module
NOTE: This module delegates ML-heavy operations to the embedding-service via HTTP.
Provides enhanced PDF text extraction using multiple backends (in embedding-service):
1. Unstructured.io - Best for complex layouts, tables, headers (Apache 2.0)
2. pypdf - Modern, BSD-licensed PDF library (recommended default)
License Compliance:
- Default backends (unstructured, pypdf) are BSD/Apache licensed
"""
import os
import logging
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
# Configuration (for backward compatibility - actual config in embedding-service)
EMBEDDING_SERVICE_URL = os.getenv("EMBEDDING_SERVICE_URL", "http://embedding-service:8087")
PDF_BACKEND = os.getenv("PDF_EXTRACTION_BACKEND", "auto")
class PDFExtractionError(Exception):
"""Error during PDF extraction."""
pass
class PDFExtractionResult:
"""Result of PDF extraction with metadata."""
def __init__(
self,
text: str,
backend_used: str,
pages: int = 0,
elements: Optional[List[Dict]] = None,
tables: Optional[List[Dict]] = None,
metadata: Optional[Dict] = None,
):
self.text = text
self.backend_used = backend_used
self.pages = pages
self.elements = elements or []
self.tables = tables or []
self.metadata = metadata or {}
def to_dict(self) -> Dict:
return {
"text": self.text,
"backend_used": self.backend_used,
"pages": self.pages,
"element_count": len(self.elements),
"table_count": len(self.tables),
"metadata": self.metadata,
}
def _detect_available_backends() -> List[str]:
"""Get available backends from embedding-service."""
import httpx
try:
with httpx.Client(timeout=5.0) as client:
response = client.get(f"{EMBEDDING_SERVICE_URL}/models")
if response.status_code == 200:
data = response.json()
return data.get("available_pdf_backends", ["pypdf"])
except Exception as e:
logger.warning(f"Could not reach embedding-service: {e}")
return []
def extract_text_from_pdf_enhanced(
pdf_content: bytes,
backend: str = PDF_BACKEND,
fallback: bool = True,
) -> PDFExtractionResult:
"""
Extract text from PDF using embedding-service.
Args:
pdf_content: PDF file content as bytes
backend: Preferred backend (auto, unstructured, pypdf)
fallback: If True, try other backends if preferred fails
Returns:
PDFExtractionResult with extracted text and metadata
"""
import httpx
try:
with httpx.Client(timeout=120.0) as client:
response = client.post(
f"{EMBEDDING_SERVICE_URL}/extract-pdf",
content=pdf_content,
headers={"Content-Type": "application/octet-stream"}
)
response.raise_for_status()
data = response.json()
return PDFExtractionResult(
text=data.get("text", ""),
backend_used=data.get("backend_used", "unknown"),
pages=data.get("pages", 0),
tables=[{"count": data.get("table_count", 0)}] if data.get("table_count", 0) > 0 else [],
metadata={"embedding_service": True}
)
except httpx.TimeoutException:
raise PDFExtractionError("PDF extraction timeout")
except httpx.HTTPStatusError as e:
raise PDFExtractionError(f"PDF extraction error: {e.response.status_code}")
except Exception as e:
raise PDFExtractionError(f"Failed to extract PDF: {str(e)}")
def extract_text_from_pdf(pdf_content: bytes) -> str:
"""
Extract text from PDF (simple interface).
This is a drop-in replacement for the original function
that uses the embedding-service internally.
"""
result = extract_text_from_pdf_enhanced(pdf_content)
return result.text
def get_pdf_extraction_info() -> dict:
"""Get information about PDF extraction configuration."""
import httpx
try:
with httpx.Client(timeout=5.0) as client:
response = client.get(f"{EMBEDDING_SERVICE_URL}/models")
if response.status_code == 200:
data = response.json()
available = data.get("available_pdf_backends", [])
return {
"configured_backend": data.get("pdf_backend", PDF_BACKEND),
"available_backends": available,
"recommended": "unstructured" if "unstructured" in available else "pypdf",
"backend_licenses": {
"unstructured": "Apache-2.0",
"pypdf": "BSD-3-Clause",
},
"commercial_safe_backends": available,
"embedding_service_url": EMBEDDING_SERVICE_URL,
"embedding_service_available": True,
}
except Exception as e:
logger.warning(f"Could not reach embedding-service: {e}")
# Fallback when embedding-service is not available
return {
"configured_backend": PDF_BACKEND,
"available_backends": [],
"recommended": None,
"backend_licenses": {},
"commercial_safe_backends": [],
"embedding_service_url": EMBEDDING_SERVICE_URL,
"embedding_service_available": False,
}
# Backward-compat shim -- module moved to korrektur/pdf_extraction.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("korrektur.pdf_extraction")
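A hedged sketch of the intended call pattern (assumes the embedding-service is reachable at EMBEDDING_SERVICE_URL; the input file name is hypothetical):

from pdf_extraction import (
    extract_text_from_pdf_enhanced,
    get_pdf_extraction_info,
    PDFExtractionError,
)

print(get_pdf_extraction_info()["embedding_service_available"])

with open("arbeit_001.pdf", "rb") as fh:  # hypothetical scanned Klausur
    pdf_bytes = fh.read()

try:
    result = extract_text_from_pdf_enhanced(pdf_bytes, backend="auto")
    print(result.backend_used, result.pages, len(result.text))
except PDFExtractionError as exc:
    print(f"extraction failed: {exc}")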
+4 -38
@@ -1,38 +1,4 @@
"""
RBAC/ABAC Policy System for Klausur-Service (barrel re-export)
This module was split into:
- rbac_types.py (Enums, data structures)
- rbac_permissions.py (Permission matrix)
- rbac_engine.py (PolicyEngine, default policies, API guards)
All public symbols are re-exported here for backwards compatibility.
"""
# Types and enums
from rbac_types import ( # noqa: F401
Role,
Action,
ResourceType,
ZKVisibilityMode,
EHVisibilityMode,
VerfahrenType,
PolicySet,
RoleAssignment,
KeyShare,
Tenant,
Namespace,
ExamPackage,
)
# Permission matrix
from rbac_permissions import DEFAULT_PERMISSIONS # noqa: F401
# Engine, policies, guards
from rbac_engine import ( # noqa: F401
PolicyEngine,
create_default_policy_sets,
get_policy_engine,
require_permission,
require_role,
)
# Backward-compat shim -- module moved to compliance/rbac.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("compliance.rbac")
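The same four-line shim is what keeps all 52 old import paths alive: importing the old module name executes the stub, which swaps the sys.modules entry for the new package module, so old and new imports yield the same object. A quick illustration (assuming backend/ and the new packages are both on sys.path):

import sys
import rbac                          # executes the shim above
import compliance.rbac as new_rbac

assert sys.modules["rbac"] is new_rbac        # old name now points at the moved module
from rbac import PolicyEngine                 # attribute lookup hits compliance.rbac
print(PolicyEngine is new_rbac.PolicyEngine)  # True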
+4 -498
@@ -1,498 +1,4 @@
"""
RBAC Policy Engine
Core engine for RBAC/ABAC permission checks,
role assignments, key shares, and default policies.
Extracted from rbac.py for file-size compliance.
"""
from typing import Optional, List, Dict, Set
from datetime import datetime, timezone
import uuid
from functools import wraps
from fastapi import HTTPException, Request
from rbac_types import (
Role,
Action,
ResourceType,
ZKVisibilityMode,
EHVisibilityMode,
PolicySet,
RoleAssignment,
KeyShare,
)
from rbac_permissions import DEFAULT_PERMISSIONS
# =============================================
# POLICY ENGINE
# =============================================
class PolicyEngine:
"""
Engine fuer RBAC/ABAC Entscheidungen.
Prueft:
1. Basis-Rollenberechtigung (RBAC)
2. Policy-Einschraenkungen (ABAC)
3. Key Share Berechtigungen
"""
def __init__(self):
self.policy_sets: Dict[str, PolicySet] = {}
self.role_assignments: Dict[str, List[RoleAssignment]] = {} # user_id -> assignments
self.key_shares: Dict[str, List[KeyShare]] = {} # user_id -> shares
def register_policy_set(self, policy: PolicySet):
"""Registriere ein Policy Set."""
self.policy_sets[policy.id] = policy
def get_policy_for_context(
self,
bundesland: str,
jahr: int,
fach: Optional[str] = None,
verfahren: str = "abitur"
) -> Optional[PolicySet]:
"""Finde das passende Policy Set fuer einen Kontext."""
# Exakte Uebereinstimmung
for policy in self.policy_sets.values():
if (policy.bundesland == bundesland and
policy.jahr == jahr and
policy.verfahren == verfahren):
if policy.fach is None or policy.fach == fach:
return policy
# Fallback: Default Policy
for policy in self.policy_sets.values():
if policy.bundesland == "DEFAULT":
return policy
return None
def assign_role(
self,
user_id: str,
role: Role,
resource_type: ResourceType,
resource_id: str,
granted_by: str,
tenant_id: Optional[str] = None,
namespace_id: Optional[str] = None,
valid_to: Optional[datetime] = None
) -> RoleAssignment:
"""Weise einem User eine Rolle zu."""
assignment = RoleAssignment(
id=str(uuid.uuid4()),
user_id=user_id,
role=role,
resource_type=resource_type,
resource_id=resource_id,
tenant_id=tenant_id,
namespace_id=namespace_id,
granted_by=granted_by,
valid_to=valid_to
)
if user_id not in self.role_assignments:
self.role_assignments[user_id] = []
self.role_assignments[user_id].append(assignment)
return assignment
def revoke_role(self, assignment_id: str, revoked_by: str) -> bool:
"""Widerrufe eine Rollenzuweisung."""
for user_assignments in self.role_assignments.values():
for assignment in user_assignments:
if assignment.id == assignment_id:
assignment.revoked_at = datetime.now(timezone.utc)
return True
return False
def get_user_roles(
self,
user_id: str,
resource_type: Optional[ResourceType] = None,
resource_id: Optional[str] = None
) -> List[Role]:
"""Hole alle aktiven Rollen eines Users."""
assignments = self.role_assignments.get(user_id, [])
roles = []
for assignment in assignments:
if not assignment.is_active():
continue
if resource_type and assignment.resource_type != resource_type:
continue
if resource_id and assignment.resource_id != resource_id:
continue
roles.append(assignment.role)
return list(set(roles))
def create_key_share(
self,
user_id: str,
package_id: str,
permissions: Set[str],
granted_by: str,
scope: str = "full",
invite_token: Optional[str] = None
) -> KeyShare:
"""Erstelle einen Key Share."""
share = KeyShare(
id=str(uuid.uuid4()),
user_id=user_id,
package_id=package_id,
permissions=permissions,
scope=scope,
granted_by=granted_by,
invite_token=invite_token
)
if user_id not in self.key_shares:
self.key_shares[user_id] = []
self.key_shares[user_id].append(share)
return share
def accept_key_share(self, share_id: str, token: str) -> bool:
"""Akzeptiere einen Key Share via Invite Token."""
for user_shares in self.key_shares.values():
for share in user_shares:
if share.id == share_id and share.invite_token == token:
share.accepted_at = datetime.now(timezone.utc)
return True
return False
def revoke_key_share(self, share_id: str, revoked_by: str) -> bool:
"""Widerrufe einen Key Share."""
for user_shares in self.key_shares.values():
for share in user_shares:
if share.id == share_id:
share.revoked_at = datetime.now(timezone.utc)
share.revoked_by = revoked_by
return True
return False
def check_permission(
self,
user_id: str,
action: Action,
resource_type: ResourceType,
resource_id: str,
policy: Optional[PolicySet] = None,
package_id: Optional[str] = None
) -> bool:
"""
Pruefe ob ein User eine Aktion ausfuehren darf.
Prueft:
1. Basis-RBAC
2. Policy-Einschraenkungen
3. Key Share (falls package_id angegeben)
"""
# 1. Hole aktive Rollen
roles = self.get_user_roles(user_id, resource_type, resource_id)
if not roles:
return False
# 2. Pruefe Basis-RBAC
has_permission = False
for role in roles:
role_permissions = DEFAULT_PERMISSIONS.get(role, {})
resource_permissions = role_permissions.get(resource_type, set())
if action in resource_permissions:
has_permission = True
break
if not has_permission:
return False
# 3. Pruefe Policy-Einschraenkungen
if policy:
# ZK Visibility Mode
if Role.ZWEITKORREKTOR in roles:
if policy.zk_visibility_mode == ZKVisibilityMode.BLIND:
# Blind: ZK darf EK-Outputs nicht sehen
if resource_type in [ResourceType.EVALUATION, ResourceType.REPORT, ResourceType.GRADE_DECISION]:
if action == Action.READ:
# Pruefe ob es EK-Outputs sind (muesste ueber Metadaten geprueft werden)
pass # Implementierung abhaengig von Datenmodell
elif policy.zk_visibility_mode == ZKVisibilityMode.SEMI:
# Semi: ZK sieht Annotationen, aber keine Note
if resource_type == ResourceType.GRADE_DECISION and action == Action.READ:
return False
# 4. Pruefe Key Share (falls Package-basiert)
if package_id:
user_shares = self.key_shares.get(user_id, [])
has_key_share = any(
share.package_id == package_id and share.is_active()
for share in user_shares
)
if not has_key_share:
return False
return True
def get_allowed_actions(
self,
user_id: str,
resource_type: ResourceType,
resource_id: str,
policy: Optional[PolicySet] = None
) -> Set[Action]:
"""Hole alle erlaubten Aktionen fuer einen User auf einer Ressource."""
roles = self.get_user_roles(user_id, resource_type, resource_id)
allowed = set()
for role in roles:
role_permissions = DEFAULT_PERMISSIONS.get(role, {})
resource_permissions = role_permissions.get(resource_type, set())
allowed.update(resource_permissions)
# Policy-Einschraenkungen anwenden
if policy and Role.ZWEITKORREKTOR in roles:
if policy.zk_visibility_mode == ZKVisibilityMode.BLIND:
# Entferne READ fuer bestimmte Ressourcen
pass # Detailimplementierung
return allowed
# =============================================
# DEFAULT POLICY SETS (alle Bundeslaender)
# =============================================
def create_default_policy_sets() -> List[PolicySet]:
"""
Erstelle Default Policy Sets fuer alle Bundeslaender.
Diese koennen spaeter pro Land verfeinert werden.
"""
bundeslaender = [
"baden-wuerttemberg", "bayern", "berlin", "brandenburg",
"bremen", "hamburg", "hessen", "mecklenburg-vorpommern",
"niedersachsen", "nordrhein-westfalen", "rheinland-pfalz",
"saarland", "sachsen", "sachsen-anhalt", "schleswig-holstein",
"thueringen"
]
policies = []
# Default Policy (Fallback)
policies.append(PolicySet(
id="DEFAULT-2025",
bundesland="DEFAULT",
jahr=2025,
fach=None,
verfahren="abitur",
zk_visibility_mode=ZKVisibilityMode.FULL,
eh_visibility_mode=EHVisibilityMode.SHARED,
allow_teacher_uploaded_eh=True,
allow_land_uploaded_eh=True,
require_rights_confirmation_on_upload=True,
third_correction_threshold=4,
final_signoff_role="fachvorsitz"
))
# Niedersachsen (Beispiel mit spezifischen Anpassungen)
policies.append(PolicySet(
id="NI-2025-ABITUR",
bundesland="niedersachsen",
jahr=2025,
fach=None,
verfahren="abitur",
zk_visibility_mode=ZKVisibilityMode.FULL, # In NI sieht ZK alles
allow_teacher_uploaded_eh=True,
allow_land_uploaded_eh=True,
require_rights_confirmation_on_upload=True,
third_correction_threshold=4,
final_signoff_role="fachvorsitz",
export_template_id="niedersachsen-abitur"
))
# Bayern (Beispiel mit SEMI visibility)
policies.append(PolicySet(
id="BY-2025-ABITUR",
bundesland="bayern",
jahr=2025,
fach=None,
verfahren="abitur",
zk_visibility_mode=ZKVisibilityMode.SEMI, # ZK sieht Annotationen, nicht Note
allow_teacher_uploaded_eh=True,
allow_land_uploaded_eh=True,
require_rights_confirmation_on_upload=True,
third_correction_threshold=4,
final_signoff_role="fachvorsitz",
export_template_id="bayern-abitur"
))
# NRW (Beispiel)
policies.append(PolicySet(
id="NW-2025-ABITUR",
bundesland="nordrhein-westfalen",
jahr=2025,
fach=None,
verfahren="abitur",
zk_visibility_mode=ZKVisibilityMode.FULL,
allow_teacher_uploaded_eh=True,
allow_land_uploaded_eh=True,
require_rights_confirmation_on_upload=True,
third_correction_threshold=4,
final_signoff_role="fachvorsitz",
export_template_id="nrw-abitur"
))
# Generiere Basis-Policies fuer alle anderen Bundeslaender
for bl in bundeslaender:
if bl not in ["niedersachsen", "bayern", "nordrhein-westfalen"]:
policies.append(PolicySet(
id=f"{bl[:2].upper()}-2025-ABITUR",
bundesland=bl,
jahr=2025,
fach=None,
verfahren="abitur",
zk_visibility_mode=ZKVisibilityMode.FULL,
allow_teacher_uploaded_eh=True,
allow_land_uploaded_eh=True,
require_rights_confirmation_on_upload=True,
third_correction_threshold=4,
final_signoff_role="fachvorsitz"
))
return policies
# =============================================
# GLOBAL POLICY ENGINE INSTANCE
# =============================================
# Singleton Policy Engine
_policy_engine: Optional[PolicyEngine] = None
def get_policy_engine() -> PolicyEngine:
"""Hole die globale Policy Engine Instanz."""
global _policy_engine
if _policy_engine is None:
_policy_engine = PolicyEngine()
# Registriere Default Policies
for policy in create_default_policy_sets():
_policy_engine.register_policy_set(policy)
return _policy_engine
# =============================================
# API GUARDS (Decorators fuer FastAPI)
# =============================================
def require_permission(
action: Action,
resource_type: ResourceType,
resource_id_param: str = "resource_id"
):
"""
Decorator fuer FastAPI Endpoints.
Prueft ob der aktuelle User die angegebene Berechtigung hat.
Usage:
@app.get("/api/v1/packages/{package_id}")
@require_permission(Action.READ, ResourceType.EXAM_PACKAGE, "package_id")
async def get_package(package_id: str, request: Request):
...
"""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
request = kwargs.get('request')
if not request:
for arg in args:
if isinstance(arg, Request):
request = arg
break
if not request:
raise HTTPException(status_code=500, detail="Request not found")
# User aus Token holen
user = getattr(request.state, 'user', None)
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user.get('user_id')
resource_id = kwargs.get(resource_id_param)
# Policy Engine pruefen
engine = get_policy_engine()
# Optional: Policy aus Kontext laden
policy = None
bundesland = user.get('bundesland')
if bundesland:
policy = engine.get_policy_for_context(bundesland, 2025)
if not engine.check_permission(
user_id=user_id,
action=action,
resource_type=resource_type,
resource_id=resource_id,
policy=policy
):
raise HTTPException(
status_code=403,
detail=f"Permission denied: {action.value} on {resource_type.value}"
)
return await func(*args, **kwargs)
return wrapper
return decorator
def require_role(role: Role):
"""
Decorator der prueft ob User eine bestimmte Rolle hat.
Usage:
@app.post("/api/v1/eh/publish")
@require_role(Role.LAND_ADMIN)
async def publish_eh(request: Request):
...
"""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
request = kwargs.get('request')
if not request:
for arg in args:
if isinstance(arg, Request):
request = arg
break
if not request:
raise HTTPException(status_code=500, detail="Request not found")
user = getattr(request.state, 'user', None)
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user.get('user_id')
engine = get_policy_engine()
user_roles = engine.get_user_roles(user_id)
if role not in user_roles:
raise HTTPException(
status_code=403,
detail=f"Role required: {role.value}"
)
return await func(*args, **kwargs)
return wrapper
return decorator
# Backward-compat shim -- module moved to compliance/rbac_engine.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("compliance.rbac_engine")
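A compact end-to-end sketch of the engine (hypothetical user and resource IDs; unknown Länder fall back to the DEFAULT-2025 policy):

from rbac_engine import get_policy_engine
from rbac_types import Role, Action, ResourceType

engine = get_policy_engine()  # registers the default policy sets on first call

engine.assign_role(
    user_id="lehrer-42",                         # hypothetical user
    role=Role.ZWEITKORREKTOR,
    resource_type=ResourceType.GRADE_DECISION,
    resource_id="klausur-2025-d-lk-007",         # hypothetical resource
    granted_by="fachvorsitz-1",
)

policy = engine.get_policy_for_context("bayern", 2025)  # SEMI visibility
allowed = engine.check_permission(
    user_id="lehrer-42",
    action=Action.READ,
    resource_type=ResourceType.GRADE_DECISION,
    resource_id="klausur-2025-d-lk-007",
    policy=policy,
)
print(allowed)  # False: under SEMI a Zweitkorrektor may not read the grade decision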
+4 -221
@@ -1,221 +1,4 @@
"""
RBAC Permission Matrix
Default role-to-resource permission mappings for
Klausur-Korrektur and Zeugnis workflows.
Extracted from rbac.py for file-size compliance.
"""
from typing import Dict, Set
from rbac_types import Role, Action, ResourceType
# =============================================
# RBAC PERMISSION MATRIX
# =============================================
# Standard-Berechtigungsmatrix (kann durch Policies ueberschrieben werden)
DEFAULT_PERMISSIONS: Dict[Role, Dict[ResourceType, Set[Action]]] = {
# Erstkorrektor
Role.ERSTKORREKTOR: {
ResourceType.EXAM_PACKAGE: {Action.READ, Action.UPDATE, Action.SHARE_KEY, Action.LOCK},
ResourceType.STUDENT_WORK: {Action.READ, Action.UPDATE},
ResourceType.EH_DOCUMENT: {Action.READ, Action.UPLOAD, Action.UPDATE},
ResourceType.RUBRIC: {Action.READ, Action.UPDATE},
ResourceType.ANNOTATION: {Action.CREATE, Action.READ, Action.UPDATE, Action.DELETE},
ResourceType.EVALUATION: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.REPORT: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.GRADE_DECISION: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.EXPORT: {Action.CREATE, Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Zweitkorrektor (Standard: FULL visibility)
Role.ZWEITKORREKTOR: {
ResourceType.EXAM_PACKAGE: {Action.READ},
ResourceType.STUDENT_WORK: {Action.READ, Action.UPDATE},
ResourceType.EH_DOCUMENT: {Action.READ},
ResourceType.RUBRIC: {Action.READ},
ResourceType.ANNOTATION: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.EVALUATION: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.REPORT: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.GRADE_DECISION: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.EXPORT: {Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Drittkorrektor
Role.DRITTKORREKTOR: {
ResourceType.EXAM_PACKAGE: {Action.READ},
ResourceType.STUDENT_WORK: {Action.READ, Action.UPDATE},
ResourceType.EH_DOCUMENT: {Action.READ},
ResourceType.RUBRIC: {Action.READ},
ResourceType.ANNOTATION: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.EVALUATION: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.REPORT: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.GRADE_DECISION: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Fachvorsitz
Role.FACHVORSITZ: {
ResourceType.TENANT: {Action.READ},
ResourceType.NAMESPACE: {Action.READ, Action.UPDATE},
ResourceType.EXAM_PACKAGE: {Action.READ, Action.UPDATE, Action.LOCK, Action.UNLOCK, Action.SIGN_OFF},
ResourceType.STUDENT_WORK: {Action.READ, Action.UPDATE},
ResourceType.EH_DOCUMENT: {Action.READ, Action.UPLOAD, Action.UPDATE},
ResourceType.RUBRIC: {Action.READ, Action.UPDATE},
ResourceType.ANNOTATION: {Action.READ, Action.UPDATE},
ResourceType.EVALUATION: {Action.READ, Action.UPDATE},
ResourceType.REPORT: {Action.READ, Action.UPDATE},
ResourceType.GRADE_DECISION: {Action.READ, Action.UPDATE, Action.SIGN_OFF},
ResourceType.EXPORT: {Action.CREATE, Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Pruefungsvorsitz
Role.PRUEFUNGSVORSITZ: {
ResourceType.TENANT: {Action.READ},
ResourceType.NAMESPACE: {Action.READ, Action.CREATE},
ResourceType.EXAM_PACKAGE: {Action.READ, Action.SIGN_OFF},
ResourceType.STUDENT_WORK: {Action.READ},
ResourceType.EH_DOCUMENT: {Action.READ},
ResourceType.GRADE_DECISION: {Action.READ, Action.SIGN_OFF},
ResourceType.EXPORT: {Action.CREATE, Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Schul-Admin
Role.SCHUL_ADMIN: {
ResourceType.TENANT: {Action.READ, Action.UPDATE},
ResourceType.NAMESPACE: {Action.CREATE, Action.READ, Action.UPDATE, Action.DELETE},
ResourceType.EXAM_PACKAGE: {Action.CREATE, Action.READ, Action.DELETE, Action.ASSIGN_ROLE},
ResourceType.EH_DOCUMENT: {Action.READ, Action.UPLOAD, Action.DELETE},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Land-Admin (Behoerde)
Role.LAND_ADMIN: {
ResourceType.TENANT: {Action.READ},
ResourceType.EH_DOCUMENT: {Action.READ, Action.UPLOAD, Action.UPDATE, Action.DELETE, Action.PUBLISH_OFFICIAL},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Auditor
Role.AUDITOR: {
ResourceType.AUDIT_LOG: {Action.READ},
ResourceType.EXAM_PACKAGE: {Action.READ}, # Nur Metadaten
# Kein Zugriff auf Inhalte!
},
# Operator
Role.OPERATOR: {
ResourceType.TENANT: {Action.READ},
ResourceType.NAMESPACE: {Action.READ},
ResourceType.EXAM_PACKAGE: {Action.READ}, # Nur Metadaten
ResourceType.AUDIT_LOG: {Action.READ},
# Break-glass separat gehandhabt
},
# Teacher Assistant
Role.TEACHER_ASSISTANT: {
ResourceType.STUDENT_WORK: {Action.READ},
ResourceType.ANNOTATION: {Action.CREATE, Action.READ}, # Nur bestimmte Typen
ResourceType.EH_DOCUMENT: {Action.READ},
},
# Exam Author (nur Vorabi)
Role.EXAM_AUTHOR: {
ResourceType.EH_DOCUMENT: {Action.CREATE, Action.READ, Action.UPDATE, Action.DELETE},
ResourceType.RUBRIC: {Action.CREATE, Action.READ, Action.UPDATE, Action.DELETE},
},
# =============================================
# ZEUGNIS-WORKFLOW ROLLEN
# =============================================
# Klassenlehrer - Erstellt Zeugnisse, Kopfnoten, Bemerkungen
Role.KLASSENLEHRER: {
ResourceType.NAMESPACE: {Action.READ},
ResourceType.ZEUGNIS: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.ZEUGNIS_ENTWURF: {Action.CREATE, Action.READ, Action.UPDATE, Action.DELETE},
ResourceType.ZEUGNIS_VORLAGE: {Action.READ},
ResourceType.SCHUELER_DATEN: {Action.READ, Action.UPDATE},
ResourceType.FACHNOTE: {Action.READ}, # Liest Fachnoten der Fachlehrer
ResourceType.KOPFNOTE: {Action.CREATE, Action.READ, Action.UPDATE},
ResourceType.FEHLZEITEN: {Action.READ, Action.UPDATE},
ResourceType.BEMERKUNG: {Action.CREATE, Action.READ, Action.UPDATE, Action.DELETE},
ResourceType.VERSETZUNG: {Action.READ},
ResourceType.EXPORT: {Action.CREATE, Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Fachlehrer - Traegt Fachnoten ein
Role.FACHLEHRER: {
ResourceType.NAMESPACE: {Action.READ},
ResourceType.SCHUELER_DATEN: {Action.READ}, # Nur eigene Schueler
ResourceType.FACHNOTE: {Action.CREATE, Action.READ, Action.UPDATE}, # Nur eigenes Fach
ResourceType.BEMERKUNG: {Action.CREATE, Action.READ}, # Fachbezogene Bemerkungen
ResourceType.AUDIT_LOG: {Action.READ},
},
# Zeugnisbeauftragter - Qualitaetskontrolle
Role.ZEUGNISBEAUFTRAGTER: {
ResourceType.NAMESPACE: {Action.READ, Action.UPDATE},
ResourceType.ZEUGNIS: {Action.READ, Action.UPDATE},
ResourceType.ZEUGNIS_ENTWURF: {Action.READ, Action.UPDATE},
ResourceType.ZEUGNIS_VORLAGE: {Action.READ, Action.UPDATE, Action.UPLOAD},
ResourceType.SCHUELER_DATEN: {Action.READ},
ResourceType.FACHNOTE: {Action.READ},
ResourceType.KOPFNOTE: {Action.READ, Action.UPDATE},
ResourceType.FEHLZEITEN: {Action.READ},
ResourceType.BEMERKUNG: {Action.READ, Action.UPDATE},
ResourceType.VERSETZUNG: {Action.READ},
ResourceType.EXPORT: {Action.CREATE, Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Sekretariat - Druck, Versand, Archivierung
Role.SEKRETARIAT: {
ResourceType.ZEUGNIS: {Action.READ, Action.DOWNLOAD},
ResourceType.ZEUGNIS_VORLAGE: {Action.READ},
ResourceType.SCHUELER_DATEN: {Action.READ}, # Fuer Adressdaten
ResourceType.EXPORT: {Action.CREATE, Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Schulleitung - Finale Zeugnis-Freigabe
Role.SCHULLEITUNG: {
ResourceType.TENANT: {Action.READ},
ResourceType.NAMESPACE: {Action.READ, Action.CREATE},
ResourceType.ZEUGNIS: {Action.READ, Action.SIGN_OFF, Action.LOCK},
ResourceType.ZEUGNIS_ENTWURF: {Action.READ, Action.UPDATE},
ResourceType.ZEUGNIS_VORLAGE: {Action.READ, Action.UPDATE},
ResourceType.SCHUELER_DATEN: {Action.READ},
ResourceType.FACHNOTE: {Action.READ},
ResourceType.KOPFNOTE: {Action.READ, Action.UPDATE},
ResourceType.FEHLZEITEN: {Action.READ},
ResourceType.BEMERKUNG: {Action.READ, Action.UPDATE},
ResourceType.KONFERENZ_BESCHLUSS: {Action.CREATE, Action.READ, Action.UPDATE, Action.SIGN_OFF},
ResourceType.VERSETZUNG: {Action.CREATE, Action.READ, Action.UPDATE, Action.SIGN_OFF},
ResourceType.EXPORT: {Action.CREATE, Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
# Stufenleitung - Stufenkoordination (z.B. Oberstufe)
Role.STUFENLEITUNG: {
ResourceType.NAMESPACE: {Action.READ, Action.UPDATE},
ResourceType.ZEUGNIS: {Action.READ, Action.UPDATE},
ResourceType.ZEUGNIS_ENTWURF: {Action.READ, Action.UPDATE},
ResourceType.SCHUELER_DATEN: {Action.READ},
ResourceType.FACHNOTE: {Action.READ},
ResourceType.KOPFNOTE: {Action.READ},
ResourceType.FEHLZEITEN: {Action.READ},
ResourceType.BEMERKUNG: {Action.READ, Action.UPDATE},
ResourceType.KONFERENZ_BESCHLUSS: {Action.READ},
ResourceType.VERSETZUNG: {Action.READ, Action.UPDATE},
ResourceType.EXPORT: {Action.READ, Action.DOWNLOAD},
ResourceType.AUDIT_LOG: {Action.READ},
},
}
# Backward-compat shim -- module moved to compliance/rbac_permissions.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("compliance.rbac_permissions")
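Because the matrix is plain data, baseline checks are dictionary lookups; the PolicyEngine then layers the ABAC policy rules on top. A minimal check (old flat imports, resolved via the shims):

from rbac_permissions import DEFAULT_PERMISSIONS
from rbac_types import Role, Action, ResourceType

zk = DEFAULT_PERMISSIONS[Role.ZWEITKORREKTOR]
print(Action.READ in zk[ResourceType.GRADE_DECISION])                 # True at the RBAC baseline
print(Action.SIGN_OFF in zk.get(ResourceType.GRADE_DECISION, set()))  # False: only Fachvorsitz/Pruefungsvorsitz sign off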
+4 -438
@@ -1,438 +1,4 @@
"""
RBAC/ABAC Type Definitions
Enums, data structures, and models for the policy system.
Extracted from rbac.py for file-size compliance.
"""
import json
from enum import Enum
from dataclasses import dataclass, field, asdict
from typing import Optional, List, Dict, Set, Any
from datetime import datetime, timezone
import uuid
# =============================================
# ENUMS: Roles, Actions, Resources
# =============================================
class Role(str, Enum):
"""Fachliche Rollen in Korrektur- und Zeugniskette."""
# === Klausur-Korrekturkette ===
ERSTKORREKTOR = "erstkorrektor" # EK
ZWEITKORREKTOR = "zweitkorrektor" # ZK
DRITTKORREKTOR = "drittkorrektor" # DK
# === Zeugnis-Workflow ===
KLASSENLEHRER = "klassenlehrer" # KL - Erstellt Zeugnis, Kopfnoten, Bemerkungen
FACHLEHRER = "fachlehrer" # FL - Traegt Fachnoten ein
ZEUGNISBEAUFTRAGTER = "zeugnisbeauftragter" # ZB - Qualitaetskontrolle
SEKRETARIAT = "sekretariat" # SEK - Druck, Versand, Archivierung
# === Leitung (Klausur + Zeugnis) ===
FACHVORSITZ = "fachvorsitz" # FVL - Fachpruefungsleitung
PRUEFUNGSVORSITZ = "pruefungsvorsitz" # PV - Schulleitung / Pruefungsvorsitz
SCHULLEITUNG = "schulleitung" # SL - Finale Zeugnis-Freigabe
STUFENLEITUNG = "stufenleitung" # STL - Stufenkoordination
# === Administration ===
SCHUL_ADMIN = "schul_admin" # SA
LAND_ADMIN = "land_admin" # LA - Behoerde
# === Spezial ===
AUDITOR = "auditor" # DSB/Auditor
OPERATOR = "operator" # OPS - Support
TEACHER_ASSISTANT = "teacher_assistant" # TA - Referendar
EXAM_AUTHOR = "exam_author" # EA - nur Vorabi
class Action(str, Enum):
"""Moegliche Operationen auf Ressourcen."""
CREATE = "create"
READ = "read"
UPDATE = "update"
DELETE = "delete"
ASSIGN_ROLE = "assign_role"
INVITE_USER = "invite_user"
REMOVE_USER = "remove_user"
UPLOAD = "upload"
DOWNLOAD = "download"
LOCK = "lock" # Finalisieren
UNLOCK = "unlock" # Nur mit Sonderrecht
SIGN_OFF = "sign_off" # Freigabe
SHARE_KEY = "share_key" # Key Share erzeugen
VIEW_PII = "view_pii" # Falls PII vorhanden
BREAK_GLASS = "break_glass" # Notfallzugriff
PUBLISH_OFFICIAL = "publish_official" # Amtliche EH verteilen
class ResourceType(str, Enum):
"""Ressourcentypen im System."""
TENANT = "tenant"
NAMESPACE = "namespace"
# === Klausur-Korrektur ===
EXAM_PACKAGE = "exam_package"
STUDENT_WORK = "student_work"
EH_DOCUMENT = "eh_document"
RUBRIC = "rubric" # Punkteraster
ANNOTATION = "annotation"
EVALUATION = "evaluation" # Kriterien/Punkte
REPORT = "report" # Gutachten
GRADE_DECISION = "grade_decision"
# === Zeugnisgenerator ===
ZEUGNIS = "zeugnis" # Zeugnisdokument
ZEUGNIS_VORLAGE = "zeugnis_vorlage" # Zeugnisvorlage/Template
ZEUGNIS_ENTWURF = "zeugnis_entwurf" # Zeugnisentwurf (vor Freigabe)
SCHUELER_DATEN = "schueler_daten" # Schueler-Stammdaten, Noten
FACHNOTE = "fachnote" # Einzelne Fachnote
KOPFNOTE = "kopfnote" # Arbeits-/Sozialverhalten
FEHLZEITEN = "fehlzeiten" # Fehlzeiten
BEMERKUNG = "bemerkung" # Zeugnisbemerkungen
KONFERENZ_BESCHLUSS = "konferenz_beschluss" # Konferenzergebnis
VERSETZUNG = "versetzung" # Versetzungsentscheidung
# === Allgemein ===
DOCUMENT = "document" # Generischer Dokumenttyp (EH, Vorlagen, etc.)
TEMPLATE = "template" # Generische Vorlagen
EXPORT = "export"
AUDIT_LOG = "audit_log"
KEY_MATERIAL = "key_material"
class ZKVisibilityMode(str, Enum):
"""Sichtbarkeitsmodus fuer Zweitkorrektoren."""
BLIND = "blind" # ZK sieht keine EK-Note/Gutachten
SEMI = "semi" # ZK sieht Annotationen, aber keine Note
FULL = "full" # ZK sieht alles
class EHVisibilityMode(str, Enum):
"""Sichtbarkeitsmodus fuer Erwartungshorizonte."""
BLIND = "blind" # ZK sieht EH nicht (selten)
SHARED = "shared" # ZK sieht EH (Standard)
class VerfahrenType(str, Enum):
"""Verfahrenstypen fuer Klausuren und Zeugnisse."""
# === Klausur/Pruefungsverfahren ===
ABITUR = "abitur"
VORABITUR = "vorabitur"
KLAUSUR = "klausur"
NACHPRUEFUNG = "nachpruefung"
# === Zeugnisverfahren ===
HALBJAHRESZEUGNIS = "halbjahreszeugnis"
JAHRESZEUGNIS = "jahreszeugnis"
ABSCHLUSSZEUGNIS = "abschlusszeugnis"
ABGANGSZEUGNIS = "abgangszeugnis"
@classmethod
def is_exam_type(cls, verfahren: str) -> bool:
"""Pruefe ob Verfahren ein Pruefungstyp ist."""
exam_types = {cls.ABITUR, cls.VORABITUR, cls.KLAUSUR, cls.NACHPRUEFUNG}
try:
return cls(verfahren) in exam_types
except ValueError:
return False
@classmethod
def is_certificate_type(cls, verfahren: str) -> bool:
"""Pruefe ob Verfahren ein Zeugnistyp ist."""
cert_types = {cls.HALBJAHRESZEUGNIS, cls.JAHRESZEUGNIS, cls.ABSCHLUSSZEUGNIS, cls.ABGANGSZEUGNIS}
try:
return cls(verfahren) in cert_types
except ValueError:
return False
# =============================================
# DATA STRUCTURES
# =============================================
@dataclass
class PolicySet:
"""
Policy-Konfiguration pro Bundesland/Jahr/Fach.
Ermoeglicht bundesland-spezifische Unterschiede ohne
harte Codierung im Quellcode.
Unterstuetzte Verfahrenstypen:
- Pruefungen: abitur, vorabitur, klausur, nachpruefung
- Zeugnisse: halbjahreszeugnis, jahreszeugnis, abschlusszeugnis, abgangszeugnis
"""
id: str
bundesland: str
jahr: int
fach: Optional[str] # None = gilt fuer alle Faecher
verfahren: str # See VerfahrenType enum
# Sichtbarkeitsregeln (Klausur)
zk_visibility_mode: ZKVisibilityMode = ZKVisibilityMode.FULL
eh_visibility_mode: EHVisibilityMode = EHVisibilityMode.SHARED
# EH-Quellen (Klausur)
allow_teacher_uploaded_eh: bool = True
allow_land_uploaded_eh: bool = True
require_rights_confirmation_on_upload: bool = True
require_dual_control_for_official_eh_update: bool = False
# Korrekturregeln (Klausur)
third_correction_threshold: int = 4 # Notenpunkte Abweichung
final_signoff_role: str = "fachvorsitz"
# Zeugnisregeln (Zeugnis)
require_klassenlehrer_approval: bool = True
require_schulleitung_signoff: bool = True
allow_sekretariat_edit_after_approval: bool = False
konferenz_protokoll_required: bool = True
bemerkungen_require_review: bool = True
fehlzeiten_auto_import: bool = True
kopfnoten_enabled: bool = False
versetzung_auto_calculate: bool = True
# Export & Anzeige
quote_verbatim_allowed: bool = False # Amtliche Texte in UI
export_template_id: str = "default"
# Zusaetzliche Flags
flags: Dict[str, Any] = field(default_factory=dict)
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
def is_exam_policy(self) -> bool:
"""Pruefe ob diese Policy fuer Pruefungen ist."""
return VerfahrenType.is_exam_type(self.verfahren)
def is_certificate_policy(self) -> bool:
"""Pruefe ob diese Policy fuer Zeugnisse ist."""
return VerfahrenType.is_certificate_type(self.verfahren)
def to_dict(self):
d = asdict(self)
d['zk_visibility_mode'] = self.zk_visibility_mode.value
d['eh_visibility_mode'] = self.eh_visibility_mode.value
d['created_at'] = self.created_at.isoformat()
return d
@dataclass
class RoleAssignment:
"""
Zuweisung einer Rolle zu einem User fuer eine spezifische Ressource.
"""
id: str
user_id: str
role: Role
resource_type: ResourceType
resource_id: str
# Optionale Einschraenkungen
tenant_id: Optional[str] = None
namespace_id: Optional[str] = None
# Gueltigkeit
valid_from: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
valid_to: Optional[datetime] = None
# Metadaten
granted_by: str = ""
granted_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
revoked_at: Optional[datetime] = None
def is_active(self) -> bool:
now = datetime.now(timezone.utc)
if self.revoked_at:
return False
if self.valid_to and now > self.valid_to:
return False
return now >= self.valid_from
def to_dict(self):
return {
'id': self.id,
'user_id': self.user_id,
'role': self.role.value,
'resource_type': self.resource_type.value,
'resource_id': self.resource_id,
'tenant_id': self.tenant_id,
'namespace_id': self.namespace_id,
'valid_from': self.valid_from.isoformat(),
'valid_to': self.valid_to.isoformat() if self.valid_to else None,
'granted_by': self.granted_by,
'granted_at': self.granted_at.isoformat(),
'revoked_at': self.revoked_at.isoformat() if self.revoked_at else None,
'is_active': self.is_active()
}
@dataclass
class KeyShare:
"""
Berechtigung fuer einen User, auf verschluesselte Inhalte zuzugreifen.
Ein KeyShare ist KEIN Schluessel im Klartext, sondern eine
Berechtigung in Verbindung mit Role Assignment.
"""
id: str
user_id: str
package_id: str
# Berechtigungsumfang
permissions: Set[str] = field(default_factory=set)
# z.B. {"read_original", "read_eh", "read_ek_outputs", "write_annotations"}
# Optionale Einschraenkungen
scope: str = "full" # "full", "original_only", "eh_only", "outputs_only"
# Kette
granted_by: str = ""
granted_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
# Akzeptanz (fuer Invite-Flow)
invite_token: Optional[str] = None
accepted_at: Optional[datetime] = None
# Widerruf
revoked_at: Optional[datetime] = None
revoked_by: Optional[str] = None
def is_active(self) -> bool:
return self.revoked_at is None and (
self.invite_token is None or self.accepted_at is not None
)
def to_dict(self):
return {
'id': self.id,
'user_id': self.user_id,
'package_id': self.package_id,
'permissions': list(self.permissions),
'scope': self.scope,
'granted_by': self.granted_by,
'granted_at': self.granted_at.isoformat(),
'invite_token': self.invite_token,
'accepted_at': self.accepted_at.isoformat() if self.accepted_at else None,
'revoked_at': self.revoked_at.isoformat() if self.revoked_at else None,
'is_active': self.is_active()
}
@dataclass
class Tenant:
"""
Hoechste Isolationseinheit - typischerweise eine Schule.
"""
id: str
name: str
bundesland: str
tenant_type: str = "school" # "school", "pruefungszentrum", "behoerde"
# Verschluesselung
encryption_enabled: bool = True
# Metadaten
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
deleted_at: Optional[datetime] = None
def to_dict(self):
return {
'id': self.id,
'name': self.name,
'bundesland': self.bundesland,
'tenant_type': self.tenant_type,
'encryption_enabled': self.encryption_enabled,
'created_at': self.created_at.isoformat()
}
@dataclass
class Namespace:
"""
Arbeitsraum innerhalb eines Tenants.
z.B. "Abitur 2026 - Deutsch LK - Kurs 12a"
"""
id: str
tenant_id: str
name: str
# Kontext
jahr: int
fach: str
kurs: Optional[str] = None
pruefungsart: str = "abitur" # "abitur", "vorabitur"
# Policy
policy_set_id: Optional[str] = None
# Metadaten
created_by: str = ""
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
deleted_at: Optional[datetime] = None
def to_dict(self):
return {
'id': self.id,
'tenant_id': self.tenant_id,
'name': self.name,
'jahr': self.jahr,
'fach': self.fach,
'kurs': self.kurs,
'pruefungsart': self.pruefungsart,
'policy_set_id': self.policy_set_id,
'created_by': self.created_by,
'created_at': self.created_at.isoformat()
}
@dataclass
class ExamPackage:
"""
Pruefungspaket - kompletter Satz Arbeiten mit allen Artefakten.
"""
id: str
namespace_id: str
tenant_id: str
name: str
beschreibung: Optional[str] = None
# Workflow-Status
status: str = "draft" # "draft", "in_progress", "locked", "signed_off"
# Beteiligte (Rollen werden separat zugewiesen)
owner_id: str = "" # Typischerweise EK
# Verschluesselung
encryption_key_id: Optional[str] = None
# Timestamps
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
locked_at: Optional[datetime] = None
signed_off_at: Optional[datetime] = None
signed_off_by: Optional[str] = None
def to_dict(self):
return {
'id': self.id,
'namespace_id': self.namespace_id,
'tenant_id': self.tenant_id,
'name': self.name,
'beschreibung': self.beschreibung,
'status': self.status,
'owner_id': self.owner_id,
'created_at': self.created_at.isoformat(),
'locked_at': self.locked_at.isoformat() if self.locked_at else None,
'signed_off_at': self.signed_off_at.isoformat() if self.signed_off_at else None,
'signed_off_by': self.signed_off_by
}
# Backward-compat shim -- module moved to compliance/rbac_types.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("compliance.rbac_types")
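A short sketch of the plain data types in isolation (hypothetical values):

from datetime import datetime, timedelta, timezone
from rbac_types import PolicySet, RoleAssignment, Role, ResourceType, VerfahrenType

zeugnis_policy = PolicySet(
    id="NI-2026-HALBJAHR",                 # hypothetical ID
    bundesland="niedersachsen",
    jahr=2026,
    fach=None,
    verfahren=VerfahrenType.HALBJAHRESZEUGNIS.value,
)
print(zeugnis_policy.is_certificate_policy())  # True
print(zeugnis_policy.is_exam_policy())         # False

assignment = RoleAssignment(
    id="ra-1",
    user_id="lehrer-42",
    role=Role.KLASSENLEHRER,
    resource_type=ResourceType.ZEUGNIS,
    resource_id="zeugnis-10b-2026",
    valid_to=datetime.now(timezone.utc) + timedelta(days=90),
)
print(assignment.is_active())  # True until valid_to passes or the assignment is revoked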
@@ -0,0 +1,6 @@
"""
training package — training API, simulation, export, TrOCR.
Backward-compatible re-exports: consumers can still use
``from training_api import ...`` etc. via the shim files in backend/.
"""
+31
@@ -0,0 +1,31 @@
"""
Training API — barrel re-export.
The actual code lives in:
- models.py (enums, Pydantic models, in-memory state)
- simulation.py (simulate_training_progress, SSE generators)
- routes.py (FastAPI router + all endpoints)
"""
# Models & enums
from .models import ( # noqa: F401
TrainingStatus,
ModelType,
TrainingConfig,
TrainingMetrics,
TrainingJob,
ModelVersion,
DatasetStats,
TrainingState,
_state,
)
# Simulation helpers
from .simulation import ( # noqa: F401
simulate_training_progress,
training_metrics_generator,
batch_ocr_progress_generator,
)
# Router
from .routes import router # noqa: F401
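Downstream wiring stays the same: the backend/ shim keeps the old flat module name working, so the router is still mounted as before (a sketch, assuming a FastAPI app):

from fastapi import FastAPI
from training_api import router as training_router  # old import path, resolved via the shim

app = FastAPI()
app.include_router(training_router)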
@@ -0,0 +1,448 @@
"""
Training Export Service for OCR Labeling Data
Exports labeled OCR data in formats suitable for fine-tuning:
- TrOCR (Microsoft's Transformer-based OCR model)
- llama3.2-vision (Meta's Vision-Language Model)
- Generic JSONL format
DATENSCHUTZ/PRIVACY:
- Alle Daten bleiben lokal auf dem Mac Mini
- Keine Cloud-Uploads ohne explizite Zustimmung
- Export-Pfade sind konfigurierbar
"""
import os
import json
import base64
import shutil
from pathlib import Path
from typing import List, Dict, Optional, Any
from dataclasses import dataclass
from datetime import datetime
import hashlib
# Export directory configuration
EXPORT_BASE_PATH = os.getenv("OCR_EXPORT_PATH", "/app/ocr-exports")
TROCR_EXPORT_PATH = os.path.join(EXPORT_BASE_PATH, "trocr")
LLAMA_VISION_EXPORT_PATH = os.path.join(EXPORT_BASE_PATH, "llama-vision")
GENERIC_EXPORT_PATH = os.path.join(EXPORT_BASE_PATH, "generic")
@dataclass
class TrainingSample:
"""A single training sample for OCR fine-tuning."""
id: str
image_path: str
ground_truth: str
ocr_text: Optional[str] = None
ocr_confidence: Optional[float] = None
metadata: Optional[Dict[str, Any]] = None
@dataclass
class ExportResult:
"""Result of a training data export."""
export_format: str
export_path: str
sample_count: int
batch_id: str
created_at: datetime
manifest_path: str
class TrOCRExporter:
"""
Export training data for TrOCR fine-tuning.
TrOCR expects:
- Image files (PNG/JPG)
- A CSV/TSV file with: image_path, text
- Or a JSONL file with: {"file_name": "img.png", "text": "ground truth"}
We use the JSONL format for flexibility.
"""
def __init__(self, export_path: str = TROCR_EXPORT_PATH):
self.export_path = export_path
os.makedirs(export_path, exist_ok=True)
def export(
self,
samples: List[TrainingSample],
batch_id: str,
copy_images: bool = True,
) -> ExportResult:
"""
Export samples in TrOCR format.
Args:
samples: List of training samples
batch_id: Unique batch identifier
copy_images: Whether to copy images to export directory
Returns:
ExportResult with export details
"""
batch_path = os.path.join(self.export_path, batch_id)
images_path = os.path.join(batch_path, "images")
os.makedirs(images_path, exist_ok=True)
# Export data
export_data = []
for sample in samples:
# Copy image if requested
if copy_images and os.path.exists(sample.image_path):
image_filename = f"{sample.id}{Path(sample.image_path).suffix}"
dest_path = os.path.join(images_path, image_filename)
shutil.copy2(sample.image_path, dest_path)
image_ref = f"images/{image_filename}"
else:
image_ref = sample.image_path
export_data.append({
"file_name": image_ref,
"text": sample.ground_truth,
"id": sample.id,
})
# Write JSONL file
jsonl_path = os.path.join(batch_path, "train.jsonl")
with open(jsonl_path, 'w', encoding='utf-8') as f:
for item in export_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# Write manifest
manifest = {
"format": "trocr",
"version": "1.0",
"batch_id": batch_id,
"sample_count": len(samples),
"created_at": datetime.utcnow().isoformat(),
"files": {
"data": "train.jsonl",
"images": "images/",
},
"model_config": {
"base_model": "microsoft/trocr-base-handwritten",
"task": "handwriting-recognition",
},
}
manifest_path = os.path.join(batch_path, "manifest.json")
with open(manifest_path, 'w') as f:
json.dump(manifest, f, indent=2)
return ExportResult(
export_format="trocr",
export_path=batch_path,
sample_count=len(samples),
batch_id=batch_id,
created_at=datetime.utcnow(),
manifest_path=manifest_path,
)
class LlamaVisionExporter:
"""
Export training data for llama3.2-vision fine-tuning.
Llama Vision fine-tuning expects:
- JSONL format with base64-encoded images or image URLs
- Format: {"messages": [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "..."}]}, {"role": "assistant", "content": "..."}]}
We create a supervised fine-tuning dataset.
"""
def __init__(self, export_path: str = LLAMA_VISION_EXPORT_PATH):
self.export_path = export_path
os.makedirs(export_path, exist_ok=True)
def _encode_image_base64(self, image_path: str) -> Optional[str]:
"""Encode image to base64."""
try:
with open(image_path, 'rb') as f:
return base64.b64encode(f.read()).decode('utf-8')
except Exception:
return None
def export(
self,
samples: List[TrainingSample],
batch_id: str,
include_base64: bool = False,
copy_images: bool = True,
) -> ExportResult:
"""
Export samples in Llama Vision fine-tuning format.
Args:
samples: List of training samples
batch_id: Unique batch identifier
include_base64: Whether to include base64-encoded images in JSONL
copy_images: Whether to copy images to export directory
Returns:
ExportResult with export details
"""
batch_path = os.path.join(self.export_path, batch_id)
images_path = os.path.join(batch_path, "images")
os.makedirs(images_path, exist_ok=True)
# OCR instruction prompt
system_prompt = (
"Du bist ein OCR-Experte für deutsche Handschrift. "
"Lies den handgeschriebenen Text im Bild und gib ihn wortgetreu wieder."
)
# Export data
export_data = []
for sample in samples:
# Copy image if requested
if copy_images and os.path.exists(sample.image_path):
image_filename = f"{sample.id}{Path(sample.image_path).suffix}"
dest_path = os.path.join(images_path, image_filename)
shutil.copy2(sample.image_path, dest_path)
image_ref = f"images/{image_filename}"
else:
image_ref = sample.image_path
# Build message format
user_content = [
{"type": "image_url", "image_url": {"url": image_ref}},
{"type": "text", "text": "Lies den handgeschriebenen Text in diesem Bild."},
]
# Optionally include base64
if include_base64:
b64 = self._encode_image_base64(sample.image_path)
if b64:
ext = Path(sample.image_path).suffix.lower().replace('.', '')
mime = {'png': 'image/png', 'jpg': 'image/jpeg', 'jpeg': 'image/jpeg'}.get(ext, 'image/png')
user_content[0] = {
"type": "image_url",
"image_url": {"url": f"data:{mime};base64,{b64}"}
}
export_data.append({
"id": sample.id,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_content},
{"role": "assistant", "content": sample.ground_truth},
],
})
# Write JSONL file
jsonl_path = os.path.join(batch_path, "train.jsonl")
with open(jsonl_path, 'w', encoding='utf-8') as f:
for item in export_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# Write manifest
manifest = {
"format": "llama_vision",
"version": "1.0",
"batch_id": batch_id,
"sample_count": len(samples),
"created_at": datetime.utcnow().isoformat(),
"files": {
"data": "train.jsonl",
"images": "images/",
},
"model_config": {
"base_model": "llama3.2-vision:11b",
"task": "handwriting-ocr",
"system_prompt": system_prompt,
},
}
manifest_path = os.path.join(batch_path, "manifest.json")
with open(manifest_path, 'w') as f:
json.dump(manifest, f, indent=2)
return ExportResult(
export_format="llama_vision",
export_path=batch_path,
sample_count=len(samples),
batch_id=batch_id,
created_at=datetime.utcnow(),
manifest_path=manifest_path,
)
class GenericExporter:
"""
Export training data in a generic JSONL format.
This format is compatible with most ML frameworks and can be
easily converted to other formats.
"""
def __init__(self, export_path: str = GENERIC_EXPORT_PATH):
self.export_path = export_path
os.makedirs(export_path, exist_ok=True)
def export(
self,
samples: List[TrainingSample],
batch_id: str,
copy_images: bool = True,
) -> ExportResult:
"""
Export samples in generic JSONL format.
Args:
samples: List of training samples
batch_id: Unique batch identifier
copy_images: Whether to copy images to export directory
Returns:
ExportResult with export details
"""
batch_path = os.path.join(self.export_path, batch_id)
images_path = os.path.join(batch_path, "images")
os.makedirs(images_path, exist_ok=True)
# Export data
export_data = []
for sample in samples:
# Copy image if requested
if copy_images and os.path.exists(sample.image_path):
image_filename = f"{sample.id}{Path(sample.image_path).suffix}"
dest_path = os.path.join(images_path, image_filename)
shutil.copy2(sample.image_path, dest_path)
image_ref = f"images/{image_filename}"
else:
image_ref = sample.image_path
export_data.append({
"id": sample.id,
"image_path": image_ref,
"ground_truth": sample.ground_truth,
"ocr_text": sample.ocr_text,
"ocr_confidence": sample.ocr_confidence,
"metadata": sample.metadata or {},
})
# Write JSONL file
jsonl_path = os.path.join(batch_path, "data.jsonl")
with open(jsonl_path, 'w', encoding='utf-8') as f:
for item in export_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# Also write as single JSON for convenience
json_path = os.path.join(batch_path, "data.json")
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(export_data, f, indent=2, ensure_ascii=False)
# Write manifest
manifest = {
"format": "generic",
"version": "1.0",
"batch_id": batch_id,
"sample_count": len(samples),
"created_at": datetime.utcnow().isoformat(),
"files": {
"data_jsonl": "data.jsonl",
"data_json": "data.json",
"images": "images/",
},
}
manifest_path = os.path.join(batch_path, "manifest.json")
with open(manifest_path, 'w') as f:
json.dump(manifest, f, indent=2)
return ExportResult(
export_format="generic",
export_path=batch_path,
sample_count=len(samples),
batch_id=batch_id,
created_at=datetime.utcnow(),
manifest_path=manifest_path,
)
class TrainingExportService:
"""
Main service for exporting OCR labeling data to various training formats.
"""
def __init__(self):
self.trocr_exporter = TrOCRExporter()
self.llama_vision_exporter = LlamaVisionExporter()
self.generic_exporter = GenericExporter()
def export(
self,
samples: List[TrainingSample],
export_format: str,
batch_id: Optional[str] = None,
**kwargs,
) -> ExportResult:
"""
Export training samples in the specified format.
Args:
samples: List of training samples
export_format: 'trocr', 'llama_vision', or 'generic'
batch_id: Optional batch ID (generated if not provided)
**kwargs: Additional format-specific options
Returns:
ExportResult with export details
"""
if not batch_id:
batch_id = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
if export_format == "trocr":
return self.trocr_exporter.export(samples, batch_id, **kwargs)
elif export_format == "llama_vision":
return self.llama_vision_exporter.export(samples, batch_id, **kwargs)
elif export_format == "generic":
return self.generic_exporter.export(samples, batch_id, **kwargs)
else:
raise ValueError(f"Unknown export format: {export_format}")
def list_exports(self, export_format: Optional[str] = None) -> List[Dict]:
"""
List all available exports.
Args:
export_format: Optional filter by format
Returns:
List of export manifests
"""
exports = []
paths_to_check = []
if export_format is None or export_format == "trocr":
paths_to_check.append((TROCR_EXPORT_PATH, "trocr"))
if export_format is None or export_format == "llama_vision":
paths_to_check.append((LLAMA_VISION_EXPORT_PATH, "llama_vision"))
if export_format is None or export_format == "generic":
paths_to_check.append((GENERIC_EXPORT_PATH, "generic"))
for base_path, fmt in paths_to_check:
if not os.path.exists(base_path):
continue
for batch_dir in os.listdir(base_path):
manifest_path = os.path.join(base_path, batch_dir, "manifest.json")
if os.path.exists(manifest_path):
with open(manifest_path, 'r') as f:
manifest = json.load(f)
manifest["export_path"] = os.path.join(base_path, batch_dir)
exports.append(manifest)
return sorted(exports, key=lambda x: x.get("created_at", ""), reverse=True)
# Singleton instance
_export_service: Optional[TrainingExportService] = None
def get_training_export_service() -> TrainingExportService:
"""Get or create the training export service singleton."""
global _export_service
if _export_service is None:
_export_service = TrainingExportService()
return _export_service
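A minimal usage sketch of the export service. The sample id, image path, and ground truth are placeholders, and the import path assumes the new training.export_service module (the old training_export_service shim resolves to the same code):
# Illustrative only -- paths, ids, and texts are placeholders.
from training.export_service import TrainingSample, get_training_export_service

samples = [
    TrainingSample(
        id="sample-001",
        image_path="/data/labeled/sample-001.png",
        ground_truth="Anna arbeitet zuverlässig und konzentriert.",
    ),
]

service = get_training_export_service()
# Dispatches to TrOCRExporter; batch_id defaults to a UTC timestamp.
result = service.export(samples, export_format="trocr")
print(result.export_path, result.sample_count)
A later call to service.list_exports("trocr") would pick the batch up again via its manifest.json.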
+118
View File
@@ -0,0 +1,118 @@
"""
Training API — enums, request/response models, and in-memory state.
"""
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any
from enum import Enum
from dataclasses import dataclass, field
from pydantic import BaseModel, Field
# ============================================================================
# ENUMS
# ============================================================================
class TrainingStatus(str, Enum):
QUEUED = "queued"
PREPARING = "preparing"
TRAINING = "training"
VALIDATING = "validating"
COMPLETED = "completed"
FAILED = "failed"
PAUSED = "paused"
CANCELLED = "cancelled"
class ModelType(str, Enum):
ZEUGNIS = "zeugnis"
KLAUSUR = "klausur"
GENERAL = "general"
# ============================================================================
# REQUEST/RESPONSE MODELS
# ============================================================================
class TrainingConfig(BaseModel):
"""Configuration for a training job."""
name: str = Field(..., description="Name for the training job")
model_type: ModelType = Field(ModelType.ZEUGNIS, description="Type of model to train")
bundeslaender: List[str] = Field(..., description="List of Bundesland codes to include")
batch_size: int = Field(16, ge=1, le=128)
learning_rate: float = Field(0.00005, ge=0.000001, le=0.1)
epochs: int = Field(10, ge=1, le=100)
warmup_steps: int = Field(500, ge=0, le=10000)
weight_decay: float = Field(0.01, ge=0, le=1)
gradient_accumulation: int = Field(4, ge=1, le=32)
mixed_precision: bool = Field(True, description="Use FP16 mixed precision training")
class TrainingMetrics(BaseModel):
"""Metrics from a training job."""
precision: float = 0.0
recall: float = 0.0
f1_score: float = 0.0
accuracy: float = 0.0
loss_history: List[float] = []
val_loss_history: List[float] = []
class TrainingJob(BaseModel):
"""A training job with full details."""
id: str
name: str
model_type: ModelType
status: TrainingStatus
progress: float
current_epoch: int
total_epochs: int
loss: float
val_loss: float
learning_rate: float
documents_processed: int
total_documents: int
started_at: Optional[datetime]
estimated_completion: Optional[datetime]
completed_at: Optional[datetime]
error_message: Optional[str]
metrics: TrainingMetrics
config: TrainingConfig
class ModelVersion(BaseModel):
"""A trained model version."""
id: str
job_id: str
version: str
model_type: ModelType
created_at: datetime
metrics: TrainingMetrics
is_active: bool
size_mb: float
bundeslaender: List[str]
class DatasetStats(BaseModel):
"""Statistics about the training dataset."""
total_documents: int
total_chunks: int
training_allowed: int
by_bundesland: Dict[str, int]
by_doc_type: Dict[str, int]
# ============================================================================
# IN-MEMORY STATE (Replace with database in production)
# ============================================================================
@dataclass
class TrainingState:
"""Global training state."""
jobs: Dict[str, dict] = field(default_factory=dict)
model_versions: Dict[str, dict] = field(default_factory=dict)
active_job_id: Optional[str] = None
_state = TrainingState()
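A short sketch of constructing a TrainingConfig; the job name and Bundesland codes are placeholders, and values outside the Field() bounds above are rejected by pydantic:
# Hypothetical values; only `name` and `bundeslaender` are required,
# the remaining hyperparameters fall back to the defaults declared above.
from training.models import TrainingConfig, ModelType

config = TrainingConfig(
    name="zeugnis-baseline",
    model_type=ModelType.ZEUGNIS,
    bundeslaender=["NI", "HB"],
    epochs=5,
)
print(config.batch_size)  # 16, from the default
# epochs=0 or batch_size=256 would raise a ValidationError,
# since they violate the ge/le constraints above.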
+303
View File
@@ -0,0 +1,303 @@
"""
Training API — FastAPI route handlers.
"""
import uuid
from datetime import datetime
from typing import List
from fastapi import APIRouter, HTTPException, BackgroundTasks, Request
from fastapi.responses import StreamingResponse
from .models import (
TrainingStatus,
TrainingConfig,
_state,
)
from .simulation import (
simulate_training_progress,
training_metrics_generator,
batch_ocr_progress_generator,
)
router = APIRouter(prefix="/api/v1/admin/training", tags=["Training"])
# ============================================================================
# TRAINING JOBS
# ============================================================================
@router.get("/jobs", response_model=List[dict])
async def list_training_jobs():
"""Get all training jobs."""
return list(_state.jobs.values())
@router.get("/jobs/{job_id}", response_model=dict)
async def get_training_job(job_id: str):
"""Get details for a specific training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
return _state.jobs[job_id]
@router.post("/jobs", response_model=dict)
async def create_training_job(config: TrainingConfig, background_tasks: BackgroundTasks):
"""Create and start a new training job."""
# Check if there's already an active job
if _state.active_job_id:
active_job = _state.jobs.get(_state.active_job_id)
if active_job and active_job["status"] in [
TrainingStatus.TRAINING.value,
TrainingStatus.PREPARING.value,
]:
raise HTTPException(
status_code=409,
detail="Another training job is already running"
)
# Create job
job_id = str(uuid.uuid4())
job = {
"id": job_id,
"name": config.name,
"model_type": config.model_type.value,
"status": TrainingStatus.QUEUED.value,
"progress": 0,
"current_epoch": 0,
"total_epochs": config.epochs,
"loss": 1.0,
"val_loss": 1.0,
"learning_rate": config.learning_rate,
"documents_processed": 0,
"total_documents": len(config.bundeslaender) * 50, # Estimate
"started_at": None,
"estimated_completion": None,
"completed_at": None,
"error_message": None,
"metrics": {
"precision": 0.0,
"recall": 0.0,
"f1_score": 0.0,
"accuracy": 0.0,
"loss_history": [],
"val_loss_history": [],
},
"config": config.dict(),
}
_state.jobs[job_id] = job
_state.active_job_id = job_id
# Start training in background
background_tasks.add_task(simulate_training_progress, job_id)
return {"id": job_id, "status": "queued", "message": "Training job created"}
@router.post("/jobs/{job_id}/pause", response_model=dict)
async def pause_training_job(job_id: str):
"""Pause a running training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = _state.jobs[job_id]
if job["status"] != TrainingStatus.TRAINING.value:
raise HTTPException(status_code=400, detail="Job is not running")
job["status"] = TrainingStatus.PAUSED.value
return {"success": True, "message": "Training paused"}
@router.post("/jobs/{job_id}/resume", response_model=dict)
async def resume_training_job(job_id: str, background_tasks: BackgroundTasks):
"""Resume a paused training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = _state.jobs[job_id]
if job["status"] != TrainingStatus.PAUSED.value:
raise HTTPException(status_code=400, detail="Job is not paused")
job["status"] = TrainingStatus.TRAINING.value
_state.active_job_id = job_id
background_tasks.add_task(simulate_training_progress, job_id)
return {"success": True, "message": "Training resumed"}
@router.post("/jobs/{job_id}/cancel", response_model=dict)
async def cancel_training_job(job_id: str):
"""Cancel a training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = _state.jobs[job_id]
job["status"] = TrainingStatus.CANCELLED.value
job["completed_at"] = datetime.now().isoformat()
if _state.active_job_id == job_id:
_state.active_job_id = None
return {"success": True, "message": "Training cancelled"}
@router.delete("/jobs/{job_id}", response_model=dict)
async def delete_training_job(job_id: str):
"""Delete a training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = _state.jobs[job_id]
if job["status"] == TrainingStatus.TRAINING.value:
raise HTTPException(status_code=400, detail="Cannot delete running job")
del _state.jobs[job_id]
return {"success": True, "message": "Job deleted"}
# ============================================================================
# MODEL VERSIONS
# ============================================================================
@router.get("/models", response_model=List[dict])
async def list_model_versions():
"""Get all trained model versions."""
return list(_state.model_versions.values())
@router.get("/models/{version_id}", response_model=dict)
async def get_model_version(version_id: str):
"""Get details for a specific model version."""
if version_id not in _state.model_versions:
raise HTTPException(status_code=404, detail="Model version not found")
return _state.model_versions[version_id]
@router.post("/models/{version_id}/activate", response_model=dict)
async def activate_model_version(version_id: str):
"""Set a model version as active."""
if version_id not in _state.model_versions:
raise HTTPException(status_code=404, detail="Model version not found")
# Deactivate all other versions of same type
model = _state.model_versions[version_id]
for v in _state.model_versions.values():
if v["model_type"] == model["model_type"]:
v["is_active"] = False
model["is_active"] = True
return {"success": True, "message": "Model activated"}
@router.delete("/models/{version_id}", response_model=dict)
async def delete_model_version(version_id: str):
"""Delete a model version."""
if version_id not in _state.model_versions:
raise HTTPException(status_code=404, detail="Model version not found")
model = _state.model_versions[version_id]
if model["is_active"]:
raise HTTPException(status_code=400, detail="Cannot delete active model")
del _state.model_versions[version_id]
return {"success": True, "message": "Model deleted"}
# ============================================================================
# DATASET STATS & STATUS
# ============================================================================
@router.get("/dataset/stats", response_model=dict)
async def get_dataset_stats():
"""Get statistics about the training dataset."""
from metrics_db import get_zeugnis_stats
zeugnis_stats = await get_zeugnis_stats()
return {
"total_documents": zeugnis_stats.get("total_documents", 0),
"total_chunks": zeugnis_stats.get("total_documents", 0) * 12,
"training_allowed": zeugnis_stats.get("training_allowed_documents", 0),
"by_bundesland": {
bl["bundesland"]: bl.get("doc_count", 0)
for bl in zeugnis_stats.get("per_bundesland", [])
},
"by_doc_type": {
"verordnung": 150,
"schulordnung": 80,
"handreichung": 45,
"erlass": 30,
},
}
@router.get("/status", response_model=dict)
async def get_training_status():
"""Get overall training system status."""
active_job = None
if _state.active_job_id and _state.active_job_id in _state.jobs:
active_job = _state.jobs[_state.active_job_id]
return {
"is_training": _state.active_job_id is not None and active_job is not None and
active_job["status"] == TrainingStatus.TRAINING.value,
"active_job_id": _state.active_job_id,
"total_jobs": len(_state.jobs),
"completed_jobs": sum(
1 for j in _state.jobs.values()
if j["status"] == TrainingStatus.COMPLETED.value
),
"failed_jobs": sum(
1 for j in _state.jobs.values()
if j["status"] == TrainingStatus.FAILED.value
),
"model_versions": len(_state.model_versions),
"active_models": sum(1 for m in _state.model_versions.values() if m["is_active"]),
}
# ============================================================================
# SSE ENDPOINTS
# ============================================================================
@router.get("/metrics/stream")
async def stream_training_metrics(job_id: str, request: Request):
"""
SSE endpoint for streaming training metrics.
Streams real-time training progress for a specific job.
"""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
return StreamingResponse(
training_metrics_generator(job_id, request),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
@router.get("/ocr/stream")
async def stream_batch_ocr(images_count: int, request: Request):
"""
SSE endpoint for streaming batch OCR progress.
Simulates batch OCR processing with progress updates.
"""
if images_count < 1 or images_count > 100:
raise HTTPException(status_code=400, detail="images_count must be between 1 and 100")
return StreamingResponse(
batch_ocr_progress_generator(images_count, request),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
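A sketch of consuming the metrics stream from a client, assuming a local backend on port 8000 and an existing job id (both placeholders); any SSE-capable HTTP client works, httpx is just one option:
# Illustrative client -- host, port, and job_id are placeholders.
import json
import httpx

job_id = "00000000-0000-0000-0000-000000000000"
url = "http://localhost:8000/api/v1/admin/training/metrics/stream"

with httpx.Client(timeout=None) as client:
    with client.stream("GET", url, params={"job_id": job_id}) as response:
        for line in response.iter_lines():
            if line.startswith("data: "):
                event = json.loads(line[len("data: "):])
                print(event["status"], event["progress"])
Each decoded event is the JSON document assembled by training_metrics_generator in the simulation module.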
@@ -0,0 +1,190 @@
"""
Training API — simulation helper and SSE generators.
"""
import json
import uuid
import asyncio
from datetime import datetime, timedelta
from .models import TrainingStatus, _state
async def simulate_training_progress(job_id: str):
"""Simulate training progress (replace with actual training logic)."""
if job_id not in _state.jobs:
return
job = _state.jobs[job_id]
job["status"] = TrainingStatus.TRAINING.value
job["started_at"] = datetime.now().isoformat()
total_steps = job["total_epochs"] * 100 # Simulate 100 steps per epoch
current_step = 0
while current_step < total_steps and job["status"] == TrainingStatus.TRAINING.value:
# Update progress
progress = (current_step / total_steps) * 100
current_epoch = current_step // 100 + 1
# Simulate decreasing loss
base_loss = 0.8 * (1 - progress / 100) + 0.1
loss = base_loss + (0.05 * (0.5 - (current_step % 100) / 100))
val_loss = loss * 1.1
# Update job state
job["progress"] = progress
job["current_epoch"] = min(current_epoch, job["total_epochs"])
job["loss"] = round(loss, 4)
job["val_loss"] = round(val_loss, 4)
job["documents_processed"] = int((progress / 100) * job["total_documents"])
# Update metrics
job["metrics"]["loss_history"].append(round(loss, 4))
job["metrics"]["val_loss_history"].append(round(val_loss, 4))
job["metrics"]["precision"] = round(0.5 + (progress / 200), 3)
job["metrics"]["recall"] = round(0.45 + (progress / 200), 3)
job["metrics"]["f1_score"] = round(0.47 + (progress / 200), 3)
job["metrics"]["accuracy"] = round(0.6 + (progress / 250), 3)
# Keep only last 50 history points
if len(job["metrics"]["loss_history"]) > 50:
job["metrics"]["loss_history"] = job["metrics"]["loss_history"][-50:]
job["metrics"]["val_loss_history"] = job["metrics"]["val_loss_history"][-50:]
# Estimate completion
if progress > 0:
elapsed = (datetime.now() - datetime.fromisoformat(job["started_at"])).total_seconds()
remaining = (elapsed / progress) * (100 - progress)
job["estimated_completion"] = (datetime.now() + timedelta(seconds=remaining)).isoformat()
current_step += 1
await asyncio.sleep(0.5) # Simulate work
# Mark as completed
if job["status"] == TrainingStatus.TRAINING.value:
job["status"] = TrainingStatus.COMPLETED.value
job["progress"] = 100
job["completed_at"] = datetime.now().isoformat()
# Create model version
version_id = str(uuid.uuid4())
_state.model_versions[version_id] = {
"id": version_id,
"job_id": job_id,
"version": f"v{len(_state.model_versions) + 1}.0",
"model_type": job["model_type"],
"created_at": datetime.now().isoformat(),
"metrics": job["metrics"],
"is_active": True,
"size_mb": 245.7,
"bundeslaender": job["config"]["bundeslaender"],
}
_state.active_job_id = None
async def training_metrics_generator(job_id: str, request):
"""
SSE generator for streaming training metrics.
Yields JSON-encoded training status updates every 500ms.
"""
while True:
# Check if client disconnected
if await request.is_disconnected():
break
# Get job status
if job_id not in _state.jobs:
yield f"data: {json.dumps({'error': 'Job not found'})}\n\n"
break
job = _state.jobs[job_id]
# Build metrics response
metrics_data = {
"job_id": job["id"],
"status": job["status"],
"progress": job["progress"],
"current_epoch": job["current_epoch"],
"total_epochs": job["total_epochs"],
"current_step": int(job["progress"] * job["total_epochs"]),
"total_steps": job["total_epochs"] * 100,
"elapsed_time_ms": 0,
"estimated_remaining_ms": 0,
"metrics": {
"loss": job["loss"],
"val_loss": job["val_loss"],
"accuracy": job["metrics"]["accuracy"],
"learning_rate": job["learning_rate"]
},
"history": [
{
"epoch": i + 1,
"step": (i + 1) * 10,
"loss": loss,
"val_loss": job["metrics"]["val_loss_history"][i] if i < len(job["metrics"]["val_loss_history"]) else None,
"learning_rate": job["learning_rate"],
"timestamp": 0
}
for i, loss in enumerate(job["metrics"]["loss_history"][-50:])
]
}
# Calculate elapsed time
if job["started_at"]:
started = datetime.fromisoformat(job["started_at"])
metrics_data["elapsed_time_ms"] = int((datetime.now() - started).total_seconds() * 1000)
# Calculate remaining time
if job["estimated_completion"]:
estimated = datetime.fromisoformat(job["estimated_completion"])
metrics_data["estimated_remaining_ms"] = max(0, int((estimated - datetime.now()).total_seconds() * 1000))
# Send SSE event
yield f"data: {json.dumps(metrics_data)}\n\n"
# Check if job completed
if job["status"] in [TrainingStatus.COMPLETED.value, TrainingStatus.FAILED.value, TrainingStatus.CANCELLED.value]:
break
# Wait before next update
await asyncio.sleep(0.5)
async def batch_ocr_progress_generator(images_count: int, request):
"""
SSE generator for batch OCR progress simulation.
In production, this would integrate with actual OCR processing.
"""
import random
for i in range(images_count):
# Check if client disconnected
if await request.is_disconnected():
break
# Simulate processing time
await asyncio.sleep(random.uniform(0.3, 0.8))
progress_data = {
"type": "progress",
"current": i + 1,
"total": images_count,
"progress_percent": ((i + 1) / images_count) * 100,
"elapsed_ms": (i + 1) * 500,
"estimated_remaining_ms": (images_count - i - 1) * 500,
"result": {
"text": f"Sample recognized text for image {i + 1}",
"confidence": round(random.uniform(0.7, 0.98), 2),
"processing_time_ms": random.randint(200, 600),
"from_cache": random.random() < 0.2
}
}
yield f"data: {json.dumps(progress_data)}\n\n"
# Send completion event
yield f"data: {json.dumps({'type': 'complete', 'total_time_ms': images_count * 500, 'processed_count': images_count})}\n\n"
@@ -0,0 +1,261 @@
"""
TrOCR API - REST endpoints for TrOCR handwriting OCR.
Provides:
- /ocr/trocr - Single image OCR
- /ocr/trocr/batch - Batch image processing
- /ocr/trocr/status - Model status
- /ocr/trocr/cache - Cache statistics
"""
from fastapi import APIRouter, UploadFile, File, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Optional
import json
import logging
from services.trocr_service import (
run_trocr_ocr_enhanced,
run_trocr_batch,
run_trocr_batch_stream,
get_model_status,
get_cache_stats,
preload_trocr_model,
OCRResult,
BatchOCRResult
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr/trocr", tags=["TrOCR"])
# =============================================================================
# MODELS
# =============================================================================
class TrOCRResponse(BaseModel):
"""Response model for single image OCR."""
text: str = Field(..., description="Extracted text")
confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")
processing_time_ms: int = Field(..., ge=0, description="Processing time in milliseconds")
model: str = Field(..., description="Model used for OCR")
has_lora_adapter: bool = Field(False, description="Whether LoRA adapter was used")
from_cache: bool = Field(False, description="Whether result was from cache")
image_hash: str = Field("", description="SHA256 hash of image (first 16 chars)")
word_count: int = Field(0, description="Number of words detected")
class BatchOCRResponse(BaseModel):
"""Response model for batch OCR."""
results: List[TrOCRResponse] = Field(..., description="Individual OCR results")
total_time_ms: int = Field(..., ge=0, description="Total processing time")
processed_count: int = Field(..., ge=0, description="Number of images processed")
cached_count: int = Field(0, description="Number of results from cache")
error_count: int = Field(0, description="Number of errors")
class ModelStatusResponse(BaseModel):
"""Response model for model status."""
status: str = Field(..., description="Model status: available, not_installed")
is_loaded: bool = Field(..., description="Whether model is loaded in memory")
model_name: Optional[str] = Field(None, description="Name of loaded model")
device: Optional[str] = Field(None, description="Device model is running on")
loaded_at: Optional[str] = Field(None, description="ISO timestamp when model was loaded")
class CacheStatsResponse(BaseModel):
"""Response model for cache statistics."""
size: int = Field(..., ge=0, description="Current cache size")
max_size: int = Field(..., ge=0, description="Maximum cache size")
ttl_seconds: int = Field(..., ge=0, description="Cache TTL in seconds")
# =============================================================================
# ENDPOINTS
# =============================================================================
@router.get("/status", response_model=ModelStatusResponse)
async def get_trocr_status():
"""
Get TrOCR model status.
Returns information about whether the model is loaded and available.
"""
return get_model_status()
@router.get("/cache", response_model=CacheStatsResponse)
async def get_trocr_cache_stats():
"""
Get TrOCR cache statistics.
Returns information about the OCR result cache.
"""
return get_cache_stats()
@router.post("/preload")
async def preload_model(handwritten: bool = Query(True, description="Load handwritten model")):
"""
Preload TrOCR model into memory.
This speeds up the first OCR request by loading the model ahead of time.
"""
success = preload_trocr_model(handwritten=handwritten)
if success:
return {"status": "success", "message": "Model preloaded successfully"}
else:
raise HTTPException(status_code=500, detail="Failed to preload model")
@router.post("", response_model=TrOCRResponse)
async def run_trocr(
file: UploadFile = File(..., description="Image file to process"),
handwritten: bool = Query(True, description="Use handwritten model"),
split_lines: bool = Query(True, description="Split image into lines"),
use_cache: bool = Query(True, description="Use result caching")
):
"""
Run TrOCR on a single image.
Supports PNG, JPG, and other common image formats.
"""
# Validate file type
if not file.content_type or not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail="File must be an image")
try:
image_data = await file.read()
result = await run_trocr_ocr_enhanced(
image_data,
handwritten=handwritten,
split_lines=split_lines,
use_cache=use_cache
)
return TrOCRResponse(
text=result.text,
confidence=result.confidence,
processing_time_ms=result.processing_time_ms,
model=result.model,
has_lora_adapter=result.has_lora_adapter,
from_cache=result.from_cache,
image_hash=result.image_hash,
word_count=len(result.text.split()) if result.text else 0
)
except Exception as e:
logger.error(f"TrOCR API error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch", response_model=BatchOCRResponse)
async def run_trocr_batch_endpoint(
files: List[UploadFile] = File(..., description="Image files to process"),
handwritten: bool = Query(True, description="Use handwritten model"),
split_lines: bool = Query(True, description="Split images into lines"),
use_cache: bool = Query(True, description="Use result caching")
):
"""
Run TrOCR on multiple images.
Processes images sequentially and returns all results.
"""
if not files:
raise HTTPException(status_code=400, detail="No files provided")
if len(files) > 50:
raise HTTPException(status_code=400, detail="Maximum 50 images per batch")
try:
images = []
for file in files:
if not file.content_type or not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")
images.append(await file.read())
batch_result = await run_trocr_batch(
images,
handwritten=handwritten,
split_lines=split_lines,
use_cache=use_cache
)
return BatchOCRResponse(
results=[
TrOCRResponse(
text=r.text,
confidence=r.confidence,
processing_time_ms=r.processing_time_ms,
model=r.model,
has_lora_adapter=r.has_lora_adapter,
from_cache=r.from_cache,
image_hash=r.image_hash,
word_count=len(r.text.split()) if r.text else 0
)
for r in batch_result.results
],
total_time_ms=batch_result.total_time_ms,
processed_count=batch_result.processed_count,
cached_count=batch_result.cached_count,
error_count=batch_result.error_count
)
except HTTPException:
raise
except Exception as e:
logger.error(f"TrOCR batch API error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch/stream")
async def run_trocr_batch_stream_endpoint(
files: List[UploadFile] = File(..., description="Image files to process"),
handwritten: bool = Query(True, description="Use handwritten model"),
split_lines: bool = Query(True, description="Split images into lines"),
use_cache: bool = Query(True, description="Use result caching")
):
"""
Run TrOCR on multiple images with Server-Sent Events (SSE) progress updates.
Returns a stream of progress events as images are processed.
"""
if not files:
raise HTTPException(status_code=400, detail="No files provided")
if len(files) > 50:
raise HTTPException(status_code=400, detail="Maximum 50 images per batch")
try:
images = []
for file in files:
if not file.content_type or not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")
images.append(await file.read())
async def event_generator():
async for update in run_trocr_batch_stream(
images,
handwritten=handwritten,
split_lines=split_lines,
use_cache=use_cache
):
yield f"data: {json.dumps(update)}\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive"
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"TrOCR stream API error: {e}")
raise HTTPException(status_code=500, detail=str(e))
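A hedged client-side sketch for the single-image endpoint; host, port, and the image file are placeholders, and the query parameters mirror the Query() defaults above:
# Illustrative call only -- adjust host/port and file path to your setup.
import httpx

with open("page_scan.png", "rb") as fh:
    response = httpx.post(
        "http://localhost:8000/api/v1/ocr/trocr",
        params={"handwritten": True, "split_lines": True, "use_cache": True},
        files={"file": ("page_scan.png", fh, "image/png")},
        timeout=120.0,
    )
response.raise_for_status()
body = response.json()
print(body["text"], body["confidence"], body["from_cache"])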
+4 -31
View File
@@ -1,31 +1,4 @@
"""
Training API — barrel re-export.
The actual code lives in:
- training_models.py (enums, Pydantic models, in-memory state)
- training_simulation.py (simulate_training_progress, SSE generators)
- training_routes.py (FastAPI router + all endpoints)
"""
# Models & enums
from training_models import ( # noqa: F401
TrainingStatus,
ModelType,
TrainingConfig,
TrainingMetrics,
TrainingJob,
ModelVersion,
DatasetStats,
TrainingState,
_state,
)
# Simulation helpers
from training_simulation import ( # noqa: F401
simulate_training_progress,
training_metrics_generator,
batch_ocr_progress_generator,
)
# Router
from training_routes import router # noqa: F401
# Backward-compat shim -- module moved to training/api.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("training.api")
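Because the shim swaps its own entry in sys.modules while the module body executes, the old flat import and the new package import end up bound to the same module object. A quick sanity check (illustrative) is:
# After the swap, both spellings yield the identical router object.
import training_api
from training.api import router as new_router

assert training_api.router is new_router
assert training_api.__name__ == "training.api"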
@@ -1,448 +1,4 @@
"""
Training Export Service for OCR Labeling Data
Exports labeled OCR data in formats suitable for fine-tuning:
- TrOCR (Microsoft's Transformer-based OCR model)
- llama3.2-vision (Meta's Vision-Language Model)
- Generic JSONL format
DATENSCHUTZ/PRIVACY:
- All data stays local on the Mac Mini
- No cloud uploads without explicit consent
- Export paths are configurable
"""
import os
import json
import base64
import shutil
from pathlib import Path
from typing import List, Dict, Optional, Any
from dataclasses import dataclass
from datetime import datetime
import hashlib
# Export directory configuration
EXPORT_BASE_PATH = os.getenv("OCR_EXPORT_PATH", "/app/ocr-exports")
TROCR_EXPORT_PATH = os.path.join(EXPORT_BASE_PATH, "trocr")
LLAMA_VISION_EXPORT_PATH = os.path.join(EXPORT_BASE_PATH, "llama-vision")
GENERIC_EXPORT_PATH = os.path.join(EXPORT_BASE_PATH, "generic")
@dataclass
class TrainingSample:
"""A single training sample for OCR fine-tuning."""
id: str
image_path: str
ground_truth: str
ocr_text: Optional[str] = None
ocr_confidence: Optional[float] = None
metadata: Optional[Dict[str, Any]] = None
@dataclass
class ExportResult:
"""Result of a training data export."""
export_format: str
export_path: str
sample_count: int
batch_id: str
created_at: datetime
manifest_path: str
class TrOCRExporter:
"""
Export training data for TrOCR fine-tuning.
TrOCR expects:
- Image files (PNG/JPG)
- A CSV/TSV file with: image_path, text
- Or a JSONL file with: {"file_name": "img.png", "text": "ground truth"}
We use the JSONL format for flexibility.
"""
def __init__(self, export_path: str = TROCR_EXPORT_PATH):
self.export_path = export_path
os.makedirs(export_path, exist_ok=True)
def export(
self,
samples: List[TrainingSample],
batch_id: str,
copy_images: bool = True,
) -> ExportResult:
"""
Export samples in TrOCR format.
Args:
samples: List of training samples
batch_id: Unique batch identifier
copy_images: Whether to copy images to export directory
Returns:
ExportResult with export details
"""
batch_path = os.path.join(self.export_path, batch_id)
images_path = os.path.join(batch_path, "images")
os.makedirs(images_path, exist_ok=True)
# Export data
export_data = []
for sample in samples:
# Copy image if requested
if copy_images and os.path.exists(sample.image_path):
image_filename = f"{sample.id}{Path(sample.image_path).suffix}"
dest_path = os.path.join(images_path, image_filename)
shutil.copy2(sample.image_path, dest_path)
image_ref = f"images/{image_filename}"
else:
image_ref = sample.image_path
export_data.append({
"file_name": image_ref,
"text": sample.ground_truth,
"id": sample.id,
})
# Write JSONL file
jsonl_path = os.path.join(batch_path, "train.jsonl")
with open(jsonl_path, 'w', encoding='utf-8') as f:
for item in export_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# Write manifest
manifest = {
"format": "trocr",
"version": "1.0",
"batch_id": batch_id,
"sample_count": len(samples),
"created_at": datetime.utcnow().isoformat(),
"files": {
"data": "train.jsonl",
"images": "images/",
},
"model_config": {
"base_model": "microsoft/trocr-base-handwritten",
"task": "handwriting-recognition",
},
}
manifest_path = os.path.join(batch_path, "manifest.json")
with open(manifest_path, 'w') as f:
json.dump(manifest, f, indent=2)
return ExportResult(
export_format="trocr",
export_path=batch_path,
sample_count=len(samples),
batch_id=batch_id,
created_at=datetime.utcnow(),
manifest_path=manifest_path,
)
class LlamaVisionExporter:
"""
Export training data for llama3.2-vision fine-tuning.
Llama Vision fine-tuning expects:
- JSONL format with base64-encoded images or image URLs
- Format: {"messages": [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "..."}]}, {"role": "assistant", "content": "..."}]}
We create a supervised fine-tuning dataset.
"""
def __init__(self, export_path: str = LLAMA_VISION_EXPORT_PATH):
self.export_path = export_path
os.makedirs(export_path, exist_ok=True)
def _encode_image_base64(self, image_path: str) -> Optional[str]:
"""Encode image to base64."""
try:
with open(image_path, 'rb') as f:
return base64.b64encode(f.read()).decode('utf-8')
except Exception:
return None
def export(
self,
samples: List[TrainingSample],
batch_id: str,
include_base64: bool = False,
copy_images: bool = True,
) -> ExportResult:
"""
Export samples in Llama Vision fine-tuning format.
Args:
samples: List of training samples
batch_id: Unique batch identifier
include_base64: Whether to include base64-encoded images in JSONL
copy_images: Whether to copy images to export directory
Returns:
ExportResult with export details
"""
batch_path = os.path.join(self.export_path, batch_id)
images_path = os.path.join(batch_path, "images")
os.makedirs(images_path, exist_ok=True)
# OCR instruction prompt
system_prompt = (
"Du bist ein OCR-Experte für deutsche Handschrift. "
"Lies den handgeschriebenen Text im Bild und gib ihn wortgetreu wieder."
)
# Export data
export_data = []
for sample in samples:
# Copy image if requested
if copy_images and os.path.exists(sample.image_path):
image_filename = f"{sample.id}{Path(sample.image_path).suffix}"
dest_path = os.path.join(images_path, image_filename)
shutil.copy2(sample.image_path, dest_path)
image_ref = f"images/{image_filename}"
else:
image_ref = sample.image_path
# Build message format
user_content = [
{"type": "image_url", "image_url": {"url": image_ref}},
{"type": "text", "text": "Lies den handgeschriebenen Text in diesem Bild."},
]
# Optionally include base64
if include_base64:
b64 = self._encode_image_base64(sample.image_path)
if b64:
ext = Path(sample.image_path).suffix.lower().replace('.', '')
mime = {'png': 'image/png', 'jpg': 'image/jpeg', 'jpeg': 'image/jpeg'}.get(ext, 'image/png')
user_content[0] = {
"type": "image_url",
"image_url": {"url": f"data:{mime};base64,{b64}"}
}
export_data.append({
"id": sample.id,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_content},
{"role": "assistant", "content": sample.ground_truth},
],
})
# Write JSONL file
jsonl_path = os.path.join(batch_path, "train.jsonl")
with open(jsonl_path, 'w', encoding='utf-8') as f:
for item in export_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# Write manifest
manifest = {
"format": "llama_vision",
"version": "1.0",
"batch_id": batch_id,
"sample_count": len(samples),
"created_at": datetime.utcnow().isoformat(),
"files": {
"data": "train.jsonl",
"images": "images/",
},
"model_config": {
"base_model": "llama3.2-vision:11b",
"task": "handwriting-ocr",
"system_prompt": system_prompt,
},
}
manifest_path = os.path.join(batch_path, "manifest.json")
with open(manifest_path, 'w') as f:
json.dump(manifest, f, indent=2)
return ExportResult(
export_format="llama_vision",
export_path=batch_path,
sample_count=len(samples),
batch_id=batch_id,
created_at=datetime.utcnow(),
manifest_path=manifest_path,
)
class GenericExporter:
"""
Export training data in a generic JSONL format.
This format is compatible with most ML frameworks and can be
easily converted to other formats.
"""
def __init__(self, export_path: str = GENERIC_EXPORT_PATH):
self.export_path = export_path
os.makedirs(export_path, exist_ok=True)
def export(
self,
samples: List[TrainingSample],
batch_id: str,
copy_images: bool = True,
) -> ExportResult:
"""
Export samples in generic JSONL format.
Args:
samples: List of training samples
batch_id: Unique batch identifier
copy_images: Whether to copy images to export directory
Returns:
ExportResult with export details
"""
batch_path = os.path.join(self.export_path, batch_id)
images_path = os.path.join(batch_path, "images")
os.makedirs(images_path, exist_ok=True)
# Export data
export_data = []
for sample in samples:
# Copy image if requested
if copy_images and os.path.exists(sample.image_path):
image_filename = f"{sample.id}{Path(sample.image_path).suffix}"
dest_path = os.path.join(images_path, image_filename)
shutil.copy2(sample.image_path, dest_path)
image_ref = f"images/{image_filename}"
else:
image_ref = sample.image_path
export_data.append({
"id": sample.id,
"image_path": image_ref,
"ground_truth": sample.ground_truth,
"ocr_text": sample.ocr_text,
"ocr_confidence": sample.ocr_confidence,
"metadata": sample.metadata or {},
})
# Write JSONL file
jsonl_path = os.path.join(batch_path, "data.jsonl")
with open(jsonl_path, 'w', encoding='utf-8') as f:
for item in export_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# Also write as single JSON for convenience
json_path = os.path.join(batch_path, "data.json")
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(export_data, f, indent=2, ensure_ascii=False)
# Write manifest
manifest = {
"format": "generic",
"version": "1.0",
"batch_id": batch_id,
"sample_count": len(samples),
"created_at": datetime.utcnow().isoformat(),
"files": {
"data_jsonl": "data.jsonl",
"data_json": "data.json",
"images": "images/",
},
}
manifest_path = os.path.join(batch_path, "manifest.json")
with open(manifest_path, 'w') as f:
json.dump(manifest, f, indent=2)
return ExportResult(
export_format="generic",
export_path=batch_path,
sample_count=len(samples),
batch_id=batch_id,
created_at=datetime.utcnow(),
manifest_path=manifest_path,
)
class TrainingExportService:
"""
Main service for exporting OCR labeling data to various training formats.
"""
def __init__(self):
self.trocr_exporter = TrOCRExporter()
self.llama_vision_exporter = LlamaVisionExporter()
self.generic_exporter = GenericExporter()
def export(
self,
samples: List[TrainingSample],
export_format: str,
batch_id: Optional[str] = None,
**kwargs,
) -> ExportResult:
"""
Export training samples in the specified format.
Args:
samples: List of training samples
export_format: 'trocr', 'llama_vision', or 'generic'
batch_id: Optional batch ID (generated if not provided)
**kwargs: Additional format-specific options
Returns:
ExportResult with export details
"""
if not batch_id:
batch_id = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
if export_format == "trocr":
return self.trocr_exporter.export(samples, batch_id, **kwargs)
elif export_format == "llama_vision":
return self.llama_vision_exporter.export(samples, batch_id, **kwargs)
elif export_format == "generic":
return self.generic_exporter.export(samples, batch_id, **kwargs)
else:
raise ValueError(f"Unknown export format: {export_format}")
def list_exports(self, export_format: Optional[str] = None) -> List[Dict]:
"""
List all available exports.
Args:
export_format: Optional filter by format
Returns:
List of export manifests
"""
exports = []
paths_to_check = []
if export_format is None or export_format == "trocr":
paths_to_check.append((TROCR_EXPORT_PATH, "trocr"))
if export_format is None or export_format == "llama_vision":
paths_to_check.append((LLAMA_VISION_EXPORT_PATH, "llama_vision"))
if export_format is None or export_format == "generic":
paths_to_check.append((GENERIC_EXPORT_PATH, "generic"))
for base_path, fmt in paths_to_check:
if not os.path.exists(base_path):
continue
for batch_dir in os.listdir(base_path):
manifest_path = os.path.join(base_path, batch_dir, "manifest.json")
if os.path.exists(manifest_path):
with open(manifest_path, 'r') as f:
manifest = json.load(f)
manifest["export_path"] = os.path.join(base_path, batch_dir)
exports.append(manifest)
return sorted(exports, key=lambda x: x.get("created_at", ""), reverse=True)
# Singleton instance
_export_service: Optional[TrainingExportService] = None
def get_training_export_service() -> TrainingExportService:
"""Get or create the training export service singleton."""
global _export_service
if _export_service is None:
_export_service = TrainingExportService()
return _export_service
# Backward-compat shim -- module moved to training/export_service.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("training.export_service")
+4 -118
View File
@@ -1,118 +1,4 @@
"""
Training API — enums, request/response models, and in-memory state.
"""
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any
from enum import Enum
from dataclasses import dataclass, field
from pydantic import BaseModel, Field
# ============================================================================
# ENUMS
# ============================================================================
class TrainingStatus(str, Enum):
QUEUED = "queued"
PREPARING = "preparing"
TRAINING = "training"
VALIDATING = "validating"
COMPLETED = "completed"
FAILED = "failed"
PAUSED = "paused"
CANCELLED = "cancelled"
class ModelType(str, Enum):
ZEUGNIS = "zeugnis"
KLAUSUR = "klausur"
GENERAL = "general"
# ============================================================================
# REQUEST/RESPONSE MODELS
# ============================================================================
class TrainingConfig(BaseModel):
"""Configuration for a training job."""
name: str = Field(..., description="Name for the training job")
model_type: ModelType = Field(ModelType.ZEUGNIS, description="Type of model to train")
bundeslaender: List[str] = Field(..., description="List of Bundesland codes to include")
batch_size: int = Field(16, ge=1, le=128)
learning_rate: float = Field(0.00005, ge=0.000001, le=0.1)
epochs: int = Field(10, ge=1, le=100)
warmup_steps: int = Field(500, ge=0, le=10000)
weight_decay: float = Field(0.01, ge=0, le=1)
gradient_accumulation: int = Field(4, ge=1, le=32)
mixed_precision: bool = Field(True, description="Use FP16 mixed precision training")
class TrainingMetrics(BaseModel):
"""Metrics from a training job."""
precision: float = 0.0
recall: float = 0.0
f1_score: float = 0.0
accuracy: float = 0.0
loss_history: List[float] = []
val_loss_history: List[float] = []
class TrainingJob(BaseModel):
"""A training job with full details."""
id: str
name: str
model_type: ModelType
status: TrainingStatus
progress: float
current_epoch: int
total_epochs: int
loss: float
val_loss: float
learning_rate: float
documents_processed: int
total_documents: int
started_at: Optional[datetime]
estimated_completion: Optional[datetime]
completed_at: Optional[datetime]
error_message: Optional[str]
metrics: TrainingMetrics
config: TrainingConfig
class ModelVersion(BaseModel):
"""A trained model version."""
id: str
job_id: str
version: str
model_type: ModelType
created_at: datetime
metrics: TrainingMetrics
is_active: bool
size_mb: float
bundeslaender: List[str]
class DatasetStats(BaseModel):
"""Statistics about the training dataset."""
total_documents: int
total_chunks: int
training_allowed: int
by_bundesland: Dict[str, int]
by_doc_type: Dict[str, int]
# ============================================================================
# IN-MEMORY STATE (Replace with database in production)
# ============================================================================
@dataclass
class TrainingState:
"""Global training state."""
jobs: Dict[str, dict] = field(default_factory=dict)
model_versions: Dict[str, dict] = field(default_factory=dict)
active_job_id: Optional[str] = None
_state = TrainingState()
# Backward-compat shim -- module moved to training/models.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("training.models")
+4 -303
View File
@@ -1,303 +1,4 @@
"""
Training API — FastAPI route handlers.
"""
import uuid
from datetime import datetime
from typing import List
from fastapi import APIRouter, HTTPException, BackgroundTasks, Request
from fastapi.responses import StreamingResponse
from training_models import (
TrainingStatus,
TrainingConfig,
_state,
)
from training_simulation import (
simulate_training_progress,
training_metrics_generator,
batch_ocr_progress_generator,
)
router = APIRouter(prefix="/api/v1/admin/training", tags=["Training"])
# ============================================================================
# TRAINING JOBS
# ============================================================================
@router.get("/jobs", response_model=List[dict])
async def list_training_jobs():
"""Get all training jobs."""
return list(_state.jobs.values())
@router.get("/jobs/{job_id}", response_model=dict)
async def get_training_job(job_id: str):
"""Get details for a specific training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
return _state.jobs[job_id]
@router.post("/jobs", response_model=dict)
async def create_training_job(config: TrainingConfig, background_tasks: BackgroundTasks):
"""Create and start a new training job."""
# Check if there's already an active job
if _state.active_job_id:
active_job = _state.jobs.get(_state.active_job_id)
if active_job and active_job["status"] in [
TrainingStatus.TRAINING.value,
TrainingStatus.PREPARING.value,
]:
raise HTTPException(
status_code=409,
detail="Another training job is already running"
)
# Create job
job_id = str(uuid.uuid4())
job = {
"id": job_id,
"name": config.name,
"model_type": config.model_type.value,
"status": TrainingStatus.QUEUED.value,
"progress": 0,
"current_epoch": 0,
"total_epochs": config.epochs,
"loss": 1.0,
"val_loss": 1.0,
"learning_rate": config.learning_rate,
"documents_processed": 0,
"total_documents": len(config.bundeslaender) * 50, # Estimate
"started_at": None,
"estimated_completion": None,
"completed_at": None,
"error_message": None,
"metrics": {
"precision": 0.0,
"recall": 0.0,
"f1_score": 0.0,
"accuracy": 0.0,
"loss_history": [],
"val_loss_history": [],
},
"config": config.dict(),
}
_state.jobs[job_id] = job
_state.active_job_id = job_id
# Start training in background
background_tasks.add_task(simulate_training_progress, job_id)
return {"id": job_id, "status": "queued", "message": "Training job created"}
@router.post("/jobs/{job_id}/pause", response_model=dict)
async def pause_training_job(job_id: str):
"""Pause a running training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = _state.jobs[job_id]
if job["status"] != TrainingStatus.TRAINING.value:
raise HTTPException(status_code=400, detail="Job is not running")
job["status"] = TrainingStatus.PAUSED.value
return {"success": True, "message": "Training paused"}
@router.post("/jobs/{job_id}/resume", response_model=dict)
async def resume_training_job(job_id: str, background_tasks: BackgroundTasks):
"""Resume a paused training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = _state.jobs[job_id]
if job["status"] != TrainingStatus.PAUSED.value:
raise HTTPException(status_code=400, detail="Job is not paused")
job["status"] = TrainingStatus.TRAINING.value
_state.active_job_id = job_id
background_tasks.add_task(simulate_training_progress, job_id)
return {"success": True, "message": "Training resumed"}
@router.post("/jobs/{job_id}/cancel", response_model=dict)
async def cancel_training_job(job_id: str):
"""Cancel a training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = _state.jobs[job_id]
job["status"] = TrainingStatus.CANCELLED.value
job["completed_at"] = datetime.now().isoformat()
if _state.active_job_id == job_id:
_state.active_job_id = None
return {"success": True, "message": "Training cancelled"}
@router.delete("/jobs/{job_id}", response_model=dict)
async def delete_training_job(job_id: str):
"""Delete a training job."""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
job = _state.jobs[job_id]
if job["status"] == TrainingStatus.TRAINING.value:
raise HTTPException(status_code=400, detail="Cannot delete running job")
del _state.jobs[job_id]
return {"success": True, "message": "Job deleted"}
# ============================================================================
# MODEL VERSIONS
# ============================================================================
@router.get("/models", response_model=List[dict])
async def list_model_versions():
"""Get all trained model versions."""
return list(_state.model_versions.values())
@router.get("/models/{version_id}", response_model=dict)
async def get_model_version(version_id: str):
"""Get details for a specific model version."""
if version_id not in _state.model_versions:
raise HTTPException(status_code=404, detail="Model version not found")
return _state.model_versions[version_id]
@router.post("/models/{version_id}/activate", response_model=dict)
async def activate_model_version(version_id: str):
"""Set a model version as active."""
if version_id not in _state.model_versions:
raise HTTPException(status_code=404, detail="Model version not found")
# Deactivate all other versions of same type
model = _state.model_versions[version_id]
for v in _state.model_versions.values():
if v["model_type"] == model["model_type"]:
v["is_active"] = False
model["is_active"] = True
return {"success": True, "message": "Model activated"}
@router.delete("/models/{version_id}", response_model=dict)
async def delete_model_version(version_id: str):
"""Delete a model version."""
if version_id not in _state.model_versions:
raise HTTPException(status_code=404, detail="Model version not found")
model = _state.model_versions[version_id]
if model["is_active"]:
raise HTTPException(status_code=400, detail="Cannot delete active model")
del _state.model_versions[version_id]
return {"success": True, "message": "Model deleted"}
# ============================================================================
# DATASET STATS & STATUS
# ============================================================================
@router.get("/dataset/stats", response_model=dict)
async def get_dataset_stats():
"""Get statistics about the training dataset."""
from metrics_db import get_zeugnis_stats
zeugnis_stats = await get_zeugnis_stats()
return {
"total_documents": zeugnis_stats.get("total_documents", 0),
"total_chunks": zeugnis_stats.get("total_documents", 0) * 12,
"training_allowed": zeugnis_stats.get("training_allowed_documents", 0),
"by_bundesland": {
bl["bundesland"]: bl.get("doc_count", 0)
for bl in zeugnis_stats.get("per_bundesland", [])
},
"by_doc_type": {
"verordnung": 150,
"schulordnung": 80,
"handreichung": 45,
"erlass": 30,
},
}
@router.get("/status", response_model=dict)
async def get_training_status():
"""Get overall training system status."""
active_job = None
if _state.active_job_id and _state.active_job_id in _state.jobs:
active_job = _state.jobs[_state.active_job_id]
return {
"is_training": _state.active_job_id is not None and active_job is not None and
active_job["status"] == TrainingStatus.TRAINING.value,
"active_job_id": _state.active_job_id,
"total_jobs": len(_state.jobs),
"completed_jobs": sum(
1 for j in _state.jobs.values()
if j["status"] == TrainingStatus.COMPLETED.value
),
"failed_jobs": sum(
1 for j in _state.jobs.values()
if j["status"] == TrainingStatus.FAILED.value
),
"model_versions": len(_state.model_versions),
"active_models": sum(1 for m in _state.model_versions.values() if m["is_active"]),
}
# ============================================================================
# SSE ENDPOINTS
# ============================================================================
@router.get("/metrics/stream")
async def stream_training_metrics(job_id: str, request: Request):
"""
SSE endpoint for streaming training metrics.
Streams real-time training progress for a specific job.
"""
if job_id not in _state.jobs:
raise HTTPException(status_code=404, detail="Job not found")
return StreamingResponse(
training_metrics_generator(job_id, request),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
@router.get("/ocr/stream")
async def stream_batch_ocr(images_count: int, request: Request):
"""
SSE endpoint for streaming batch OCR progress.
Simulates batch OCR processing with progress updates.
"""
if images_count < 1 or images_count > 100:
raise HTTPException(status_code=400, detail="images_count must be between 1 and 100")
return StreamingResponse(
batch_ocr_progress_generator(images_count, request),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# Backward-compat shim -- module moved to training/routes.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("training.routes")
+4 -190
View File
@@ -1,190 +1,4 @@
"""
Training API — simulation helper and SSE generators.
"""
import json
import uuid
import asyncio
from datetime import datetime, timedelta
from training_models import TrainingStatus, _state
async def simulate_training_progress(job_id: str):
"""Simulate training progress (replace with actual training logic)."""
if job_id not in _state.jobs:
return
job = _state.jobs[job_id]
job["status"] = TrainingStatus.TRAINING.value
job["started_at"] = datetime.now().isoformat()
total_steps = job["total_epochs"] * 100 # Simulate 100 steps per epoch
current_step = 0
while current_step < total_steps and job["status"] == TrainingStatus.TRAINING.value:
# Update progress
progress = (current_step / total_steps) * 100
current_epoch = current_step // 100 + 1
# Simulate decreasing loss
base_loss = 0.8 * (1 - progress / 100) + 0.1
loss = base_loss + (0.05 * (0.5 - (current_step % 100) / 100))
val_loss = loss * 1.1
# Update job state
job["progress"] = progress
job["current_epoch"] = min(current_epoch, job["total_epochs"])
job["loss"] = round(loss, 4)
job["val_loss"] = round(val_loss, 4)
job["documents_processed"] = int((progress / 100) * job["total_documents"])
# Update metrics
job["metrics"]["loss_history"].append(round(loss, 4))
job["metrics"]["val_loss_history"].append(round(val_loss, 4))
job["metrics"]["precision"] = round(0.5 + (progress / 200), 3)
job["metrics"]["recall"] = round(0.45 + (progress / 200), 3)
job["metrics"]["f1_score"] = round(0.47 + (progress / 200), 3)
job["metrics"]["accuracy"] = round(0.6 + (progress / 250), 3)
# Keep only last 50 history points
if len(job["metrics"]["loss_history"]) > 50:
job["metrics"]["loss_history"] = job["metrics"]["loss_history"][-50:]
job["metrics"]["val_loss_history"] = job["metrics"]["val_loss_history"][-50:]
# Estimate completion
if progress > 0:
elapsed = (datetime.now() - datetime.fromisoformat(job["started_at"])).total_seconds()
remaining = (elapsed / progress) * (100 - progress)
job["estimated_completion"] = (datetime.now() + timedelta(seconds=remaining)).isoformat()
current_step += 1
await asyncio.sleep(0.5) # Simulate work
# Mark as completed
if job["status"] == TrainingStatus.TRAINING.value:
job["status"] = TrainingStatus.COMPLETED.value
job["progress"] = 100
job["completed_at"] = datetime.now().isoformat()
# Create model version
version_id = str(uuid.uuid4())
_state.model_versions[version_id] = {
"id": version_id,
"job_id": job_id,
"version": f"v{len(_state.model_versions) + 1}.0",
"model_type": job["model_type"],
"created_at": datetime.now().isoformat(),
"metrics": job["metrics"],
"is_active": True,
"size_mb": 245.7,
"bundeslaender": job["config"]["bundeslaender"],
}
_state.active_job_id = None
async def training_metrics_generator(job_id: str, request):
"""
SSE generator for streaming training metrics.
Yields JSON-encoded training status updates every 500ms.
"""
while True:
# Check if client disconnected
if await request.is_disconnected():
break
# Get job status
if job_id not in _state.jobs:
yield f"data: {json.dumps({'error': 'Job not found'})}\n\n"
break
job = _state.jobs[job_id]
# Build metrics response
metrics_data = {
"job_id": job["id"],
"status": job["status"],
"progress": job["progress"],
"current_epoch": job["current_epoch"],
"total_epochs": job["total_epochs"],
"current_step": int(job["progress"] * job["total_epochs"]),
"total_steps": job["total_epochs"] * 100,
"elapsed_time_ms": 0,
"estimated_remaining_ms": 0,
"metrics": {
"loss": job["loss"],
"val_loss": job["val_loss"],
"accuracy": job["metrics"]["accuracy"],
"learning_rate": job["learning_rate"]
},
"history": [
{
"epoch": i + 1,
"step": (i + 1) * 10,
"loss": loss,
"val_loss": job["metrics"]["val_loss_history"][i] if i < len(job["metrics"]["val_loss_history"]) else None,
"learning_rate": job["learning_rate"],
"timestamp": 0
}
for i, loss in enumerate(job["metrics"]["loss_history"][-50:])
]
}
# Calculate elapsed time
if job["started_at"]:
started = datetime.fromisoformat(job["started_at"])
metrics_data["elapsed_time_ms"] = int((datetime.now() - started).total_seconds() * 1000)
# Calculate remaining time
if job["estimated_completion"]:
estimated = datetime.fromisoformat(job["estimated_completion"])
metrics_data["estimated_remaining_ms"] = max(0, int((estimated - datetime.now()).total_seconds() * 1000))
# Send SSE event
yield f"data: {json.dumps(metrics_data)}\n\n"
# Check if job completed
if job["status"] in [TrainingStatus.COMPLETED.value, TrainingStatus.FAILED.value, TrainingStatus.CANCELLED.value]:
break
# Wait before next update
await asyncio.sleep(0.5)
async def batch_ocr_progress_generator(images_count: int, request):
"""
SSE generator for batch OCR progress simulation.
In production, this would integrate with actual OCR processing.
"""
import random
for i in range(images_count):
# Check if client disconnected
if await request.is_disconnected():
break
# Simulate processing time
await asyncio.sleep(random.uniform(0.3, 0.8))
progress_data = {
"type": "progress",
"current": i + 1,
"total": images_count,
"progress_percent": ((i + 1) / images_count) * 100,
"elapsed_ms": (i + 1) * 500,
"estimated_remaining_ms": (images_count - i - 1) * 500,
"result": {
"text": f"Sample recognized text for image {i + 1}",
"confidence": round(random.uniform(0.7, 0.98), 2),
"processing_time_ms": random.randint(200, 600),
"from_cache": random.random() < 0.2
}
}
yield f"data: {json.dumps(progress_data)}\n\n"
# Send completion event
yield f"data: {json.dumps({'type': 'complete', 'total_time_ms': images_count * 500, 'processed_count': images_count})}\n\n"
# Backward-compat shim -- module moved to training/simulation.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("training.simulation")
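For reference, a minimal client sketch for consuming one of the SSE streams produced by these generators. The base URL and route path are assumptions; the frames themselves are plain `data: <json>` lines, emitted roughly every 500 ms until the job finishes:

import json
import httpx

def follow_training(job_id: str, base_url: str = "http://localhost:8000") -> None:
    # Route path is a placeholder; use whatever path the training router actually exposes.
    url = f"{base_url}/api/v1/training/jobs/{job_id}/metrics/stream"
    with httpx.stream("GET", url, timeout=None) as resp:
        for line in resp.iter_lines():
            if not line.startswith("data: "):
                continue
            event = json.loads(line[len("data: "):])
            metrics = event.get("metrics", {})
            print(f'{event.get("progress", 0):.1f}%  loss={metrics.get("loss")}')
            if event.get("progress", 0) >= 100:
                break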
+4 -261
View File
@@ -1,261 +1,4 @@
"""
TrOCR API - REST endpoints for TrOCR handwriting OCR.
Provides:
- POST /api/v1/ocr/trocr              - Single image OCR
- POST /api/v1/ocr/trocr/batch        - Batch image processing
- POST /api/v1/ocr/trocr/batch/stream - Batch processing with SSE progress
- POST /api/v1/ocr/trocr/preload      - Preload the model into memory
- GET  /api/v1/ocr/trocr/status       - Model status
- GET  /api/v1/ocr/trocr/cache        - Cache statistics
"""
from fastapi import APIRouter, UploadFile, File, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Optional
import json
import logging
from services.trocr_service import (
run_trocr_ocr_enhanced,
run_trocr_batch,
run_trocr_batch_stream,
get_model_status,
get_cache_stats,
preload_trocr_model,
OCRResult,
BatchOCRResult
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/ocr/trocr", tags=["TrOCR"])
# =============================================================================
# MODELS
# =============================================================================
class TrOCRResponse(BaseModel):
"""Response model for single image OCR."""
text: str = Field(..., description="Extracted text")
confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")
processing_time_ms: int = Field(..., ge=0, description="Processing time in milliseconds")
model: str = Field(..., description="Model used for OCR")
has_lora_adapter: bool = Field(False, description="Whether LoRA adapter was used")
from_cache: bool = Field(False, description="Whether result was from cache")
image_hash: str = Field("", description="SHA256 hash of image (first 16 chars)")
word_count: int = Field(0, description="Number of words detected")
class BatchOCRResponse(BaseModel):
"""Response model for batch OCR."""
results: List[TrOCRResponse] = Field(..., description="Individual OCR results")
total_time_ms: int = Field(..., ge=0, description="Total processing time")
processed_count: int = Field(..., ge=0, description="Number of images processed")
cached_count: int = Field(0, description="Number of results from cache")
error_count: int = Field(0, description="Number of errors")
class ModelStatusResponse(BaseModel):
"""Response model for model status."""
status: str = Field(..., description="Model status: available, not_installed")
is_loaded: bool = Field(..., description="Whether model is loaded in memory")
model_name: Optional[str] = Field(None, description="Name of loaded model")
device: Optional[str] = Field(None, description="Device model is running on")
loaded_at: Optional[str] = Field(None, description="ISO timestamp when model was loaded")
class CacheStatsResponse(BaseModel):
"""Response model for cache statistics."""
size: int = Field(..., ge=0, description="Current cache size")
max_size: int = Field(..., ge=0, description="Maximum cache size")
ttl_seconds: int = Field(..., ge=0, description="Cache TTL in seconds")
# =============================================================================
# ENDPOINTS
# =============================================================================
@router.get("/status", response_model=ModelStatusResponse)
async def get_trocr_status():
"""
Get TrOCR model status.
Returns information about whether the model is loaded and available.
"""
return get_model_status()
@router.get("/cache", response_model=CacheStatsResponse)
async def get_trocr_cache_stats():
"""
Get TrOCR cache statistics.
Returns information about the OCR result cache.
"""
return get_cache_stats()
@router.post("/preload")
async def preload_model(handwritten: bool = Query(True, description="Load handwritten model")):
"""
Preload TrOCR model into memory.
This speeds up the first OCR request by loading the model ahead of time.
"""
success = preload_trocr_model(handwritten=handwritten)
if success:
return {"status": "success", "message": "Model preloaded successfully"}
else:
raise HTTPException(status_code=500, detail="Failed to preload model")
@router.post("", response_model=TrOCRResponse)
async def run_trocr(
file: UploadFile = File(..., description="Image file to process"),
handwritten: bool = Query(True, description="Use handwritten model"),
split_lines: bool = Query(True, description="Split image into lines"),
use_cache: bool = Query(True, description="Use result caching")
):
"""
Run TrOCR on a single image.
Supports PNG, JPG, and other common image formats.
"""
# Validate file type
if not file.content_type or not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail="File must be an image")
try:
image_data = await file.read()
result = await run_trocr_ocr_enhanced(
image_data,
handwritten=handwritten,
split_lines=split_lines,
use_cache=use_cache
)
return TrOCRResponse(
text=result.text,
confidence=result.confidence,
processing_time_ms=result.processing_time_ms,
model=result.model,
has_lora_adapter=result.has_lora_adapter,
from_cache=result.from_cache,
image_hash=result.image_hash,
word_count=len(result.text.split()) if result.text else 0
)
except Exception as e:
logger.error(f"TrOCR API error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch", response_model=BatchOCRResponse)
async def run_trocr_batch_endpoint(
files: List[UploadFile] = File(..., description="Image files to process"),
handwritten: bool = Query(True, description="Use handwritten model"),
split_lines: bool = Query(True, description="Split images into lines"),
use_cache: bool = Query(True, description="Use result caching")
):
"""
Run TrOCR on multiple images.
Processes images sequentially and returns all results.
"""
if not files:
raise HTTPException(status_code=400, detail="No files provided")
if len(files) > 50:
raise HTTPException(status_code=400, detail="Maximum 50 images per batch")
try:
images = []
for file in files:
if not file.content_type or not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")
images.append(await file.read())
batch_result = await run_trocr_batch(
images,
handwritten=handwritten,
split_lines=split_lines,
use_cache=use_cache
)
return BatchOCRResponse(
results=[
TrOCRResponse(
text=r.text,
confidence=r.confidence,
processing_time_ms=r.processing_time_ms,
model=r.model,
has_lora_adapter=r.has_lora_adapter,
from_cache=r.from_cache,
image_hash=r.image_hash,
word_count=len(r.text.split()) if r.text else 0
)
for r in batch_result.results
],
total_time_ms=batch_result.total_time_ms,
processed_count=batch_result.processed_count,
cached_count=batch_result.cached_count,
error_count=batch_result.error_count
)
except HTTPException:
raise
except Exception as e:
logger.error(f"TrOCR batch API error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch/stream")
async def run_trocr_batch_stream_endpoint(
files: List[UploadFile] = File(..., description="Image files to process"),
handwritten: bool = Query(True, description="Use handwritten model"),
split_lines: bool = Query(True, description="Split images into lines"),
use_cache: bool = Query(True, description="Use result caching")
):
"""
Run TrOCR on multiple images with Server-Sent Events (SSE) progress updates.
Returns a stream of progress events as images are processed.
"""
if not files:
raise HTTPException(status_code=400, detail="No files provided")
if len(files) > 50:
raise HTTPException(status_code=400, detail="Maximum 50 images per batch")
try:
images = []
for file in files:
if not file.content_type or not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail=f"File {file.filename} is not an image")
images.append(await file.read())
async def event_generator():
async for update in run_trocr_batch_stream(
images,
handwritten=handwritten,
split_lines=split_lines,
use_cache=use_cache
):
yield f"data: {json.dumps(update)}\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive"
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"TrOCR stream API error: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Backward-compat shim -- module moved to training/trocr_api.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("training.trocr_api")
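A minimal client sketch for the single-image endpoint above (base URL is an assumption; the path follows the router prefix /api/v1/ocr/trocr and the File/Query parameters shown in run_trocr):

import httpx

def ocr_image(path: str, base_url: str = "http://localhost:8000") -> dict:
    with open(path, "rb") as fh:
        resp = httpx.post(
            f"{base_url}/api/v1/ocr/trocr",
            files={"file": (path, fh, "image/png")},
            params={"handwritten": True, "split_lines": True, "use_cache": True},
        )
    resp.raise_for_status()
    return resp.json()  # TrOCRResponse fields: text, confidence, processing_time_ms, ...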
@@ -0,0 +1,6 @@
"""
worksheet package — worksheet editor, NRU generator, cleanup.
Backward-compatible re-exports: consumers can still use
``from worksheet_editor_api import ...`` etc. via the shim files in backend/.
"""
@@ -0,0 +1,491 @@
"""
Worksheet Cleanup API - handwriting removal and layout reconstruction
Endpoints:
- POST /api/v1/worksheet/detect-handwriting - detects handwriting and returns a mask
- POST /api/v1/worksheet/remove-handwriting - removes handwriting from an image
- POST /api/v1/worksheet/reconstruct - reconstructs the layout as Fabric.js JSON
- POST /api/v1/worksheet/cleanup-pipeline - full pipeline (detection + removal + layout)
PRIVACY: All processing runs locally on the Mac Mini.
"""
import io
import base64
import logging
from typing import Optional
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from services.handwriting_detection import (
detect_handwriting,
detect_handwriting_regions,
mask_to_png
)
from services.inpainting_service import (
inpaint_image,
remove_handwriting,
InpaintingMethod,
check_lama_available
)
from services.layout_reconstruction_service import (
reconstruct_layout,
layout_to_fabric_json,
reconstruct_and_clean
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/worksheet", tags=["Worksheet Cleanup"])
# =============================================================================
# Pydantic Models
# =============================================================================
class DetectionResponse(BaseModel):
has_handwriting: bool
confidence: float
handwriting_ratio: float
detection_method: str
mask_base64: Optional[str] = None
class InpaintingResponse(BaseModel):
success: bool
method_used: str
processing_time_ms: float
image_base64: Optional[str] = None
error: Optional[str] = None
class ReconstructionResponse(BaseModel):
success: bool
element_count: int
page_width: int
page_height: int
fabric_json: dict
table_count: int = 0
class PipelineResponse(BaseModel):
success: bool
handwriting_detected: bool
handwriting_removed: bool
layout_reconstructed: bool
cleaned_image_base64: Optional[str] = None
fabric_json: Optional[dict] = None
metadata: dict = {}
class CapabilitiesResponse(BaseModel):
opencv_available: bool = True
lama_available: bool = False
paddleocr_available: bool = False
# =============================================================================
# API Endpoints
# =============================================================================
@router.get("/capabilities")
async def get_capabilities() -> CapabilitiesResponse:
"""
Get available cleanup capabilities on this server.
"""
# Check PaddleOCR
paddleocr_available = False
try:
from hybrid_vocab_extractor import get_paddle_ocr
ocr = get_paddle_ocr()
paddleocr_available = ocr is not None
except Exception:
pass
return CapabilitiesResponse(
opencv_available=True,
lama_available=check_lama_available(),
paddleocr_available=paddleocr_available
)
@router.post("/detect-handwriting")
async def detect_handwriting_endpoint(
image: UploadFile = File(...),
return_mask: bool = Form(default=True),
min_confidence: float = Form(default=0.3)
) -> DetectionResponse:
"""
Detect handwriting in an image.
Args:
image: Input image (PNG, JPG)
return_mask: Whether to return the binary mask as base64
min_confidence: Minimum confidence threshold
Returns:
DetectionResponse with detection results and optional mask
"""
logger.info(f"Handwriting detection request: {image.filename}")
# Validate file type
content_type = image.content_type or ""
if not content_type.startswith("image/"):
raise HTTPException(
status_code=400,
detail="Only image files (PNG, JPG) are supported"
)
try:
image_bytes = await image.read()
# Detect handwriting
result = detect_handwriting(image_bytes)
has_handwriting = (
result.confidence >= min_confidence and
result.handwriting_ratio > 0.005
)
response = DetectionResponse(
has_handwriting=has_handwriting,
confidence=result.confidence,
handwriting_ratio=result.handwriting_ratio,
detection_method=result.detection_method
)
if return_mask:
mask_bytes = mask_to_png(result.mask)
response.mask_base64 = base64.b64encode(mask_bytes).decode('utf-8')
logger.info(f"Detection complete: handwriting={has_handwriting}, "
f"confidence={result.confidence:.2f}")
return response
except Exception as e:
logger.error(f"Handwriting detection failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/detect-handwriting/mask")
async def get_handwriting_mask(
image: UploadFile = File(...)
) -> StreamingResponse:
"""
Get handwriting detection mask as PNG image.
Returns binary mask where white (255) = handwriting.
"""
content_type = image.content_type or ""
if not content_type.startswith("image/"):
raise HTTPException(
status_code=400,
detail="Only image files are supported"
)
try:
image_bytes = await image.read()
result = detect_handwriting(image_bytes)
mask_bytes = mask_to_png(result.mask)
return StreamingResponse(
io.BytesIO(mask_bytes),
media_type="image/png",
headers={
"Content-Disposition": "attachment; filename=handwriting_mask.png"
}
)
except Exception as e:
logger.error(f"Mask generation failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/remove-handwriting")
async def remove_handwriting_endpoint(
image: UploadFile = File(...),
mask: Optional[UploadFile] = File(default=None),
method: str = Form(default="auto"),
return_base64: bool = Form(default=False)
):
"""
Remove handwriting from an image.
Args:
image: Input image with handwriting
mask: Optional pre-computed mask (if not provided, auto-detected)
method: Inpainting method (auto, opencv_telea, opencv_ns, lama)
return_base64: If True, return image as base64, else as file
Returns:
Cleaned image (as PNG file or base64 in JSON)
"""
logger.info(f"Remove handwriting request: {image.filename}, method={method}")
content_type = image.content_type or ""
if not content_type.startswith("image/"):
raise HTTPException(
status_code=400,
detail="Only image files are supported"
)
try:
image_bytes = await image.read()
# Get mask if provided
mask_array = None
if mask is not None:
mask_bytes = await mask.read()
from PIL import Image
import numpy as np
mask_img = Image.open(io.BytesIO(mask_bytes))
mask_array = np.array(mask_img)
# Select inpainting method
inpainting_method = InpaintingMethod.AUTO
if method == "opencv_telea":
inpainting_method = InpaintingMethod.OPENCV_TELEA
elif method == "opencv_ns":
inpainting_method = InpaintingMethod.OPENCV_NS
elif method == "lama":
inpainting_method = InpaintingMethod.LAMA
# Remove handwriting
cleaned_bytes, metadata = remove_handwriting(
image_bytes,
mask=mask_array,
method=inpainting_method
)
if return_base64:
return JSONResponse({
"success": True,
"image_base64": base64.b64encode(cleaned_bytes).decode('utf-8'),
"metadata": metadata
})
else:
return StreamingResponse(
io.BytesIO(cleaned_bytes),
media_type="image/png",
headers={
"Content-Disposition": "attachment; filename=cleaned.png",
"X-Method-Used": metadata.get("method_used", "unknown"),
"X-Processing-Time-Ms": str(metadata.get("processing_time_ms", 0))
}
)
except Exception as e:
logger.error(f"Handwriting removal failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/reconstruct")
async def reconstruct_layout_endpoint(
image: UploadFile = File(...),
clean_handwriting: bool = Form(default=True),
detect_tables: bool = Form(default=True)
) -> ReconstructionResponse:
"""
Reconstruct worksheet layout and generate Fabric.js JSON.
Args:
image: Input image (can contain handwriting)
clean_handwriting: Whether to remove handwriting first
detect_tables: Whether to detect table structures
Returns:
ReconstructionResponse with Fabric.js JSON
"""
logger.info(f"Layout reconstruction request: {image.filename}")
content_type = image.content_type or ""
if not content_type.startswith("image/"):
raise HTTPException(
status_code=400,
detail="Only image files are supported"
)
try:
image_bytes = await image.read()
# Run reconstruction pipeline
if clean_handwriting:
cleaned_bytes, layout = reconstruct_and_clean(image_bytes)
else:
layout = reconstruct_layout(image_bytes, detect_tables=detect_tables)
return ReconstructionResponse(
success=True,
element_count=len(layout.elements),
page_width=layout.page_width,
page_height=layout.page_height,
fabric_json=layout.fabric_json,
table_count=len(layout.table_regions)
)
except Exception as e:
logger.error(f"Layout reconstruction failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/cleanup-pipeline")
async def full_cleanup_pipeline(
image: UploadFile = File(...),
remove_hw: bool = Form(default=True, alias="remove_handwriting"),
reconstruct: bool = Form(default=True),
inpainting_method: str = Form(default="auto")
) -> PipelineResponse:
"""
Full cleanup pipeline: detect, remove handwriting, reconstruct layout.
This is the recommended endpoint for processing filled worksheets.
Args:
image: Input image (scan/photo of filled worksheet)
remove_handwriting: Whether to remove detected handwriting
reconstruct: Whether to reconstruct layout as Fabric.js JSON
inpainting_method: Method for inpainting (auto, opencv_telea, opencv_ns, lama)
Returns:
PipelineResponse with cleaned image and Fabric.js JSON
"""
logger.info(f"Full cleanup pipeline: {image.filename}")
content_type = image.content_type or ""
if not content_type.startswith("image/"):
raise HTTPException(
status_code=400,
detail="Only image files are supported"
)
try:
image_bytes = await image.read()
metadata = {}
# Step 1: Detect handwriting
detection = detect_handwriting(image_bytes)
handwriting_detected = (
detection.confidence >= 0.3 and
detection.handwriting_ratio > 0.005
)
metadata["detection"] = {
"confidence": detection.confidence,
"handwriting_ratio": detection.handwriting_ratio,
"method": detection.detection_method
}
# Step 2: Remove handwriting if requested and detected
cleaned_bytes = image_bytes
handwriting_removed = False
if remove_hw and handwriting_detected:
method = InpaintingMethod.AUTO
if inpainting_method == "opencv_telea":
method = InpaintingMethod.OPENCV_TELEA
elif inpainting_method == "opencv_ns":
method = InpaintingMethod.OPENCV_NS
elif inpainting_method == "lama":
method = InpaintingMethod.LAMA
cleaned_bytes, inpaint_metadata = remove_handwriting(
image_bytes,
mask=detection.mask,
method=method
)
handwriting_removed = inpaint_metadata.get("inpainting_performed", False)
metadata["inpainting"] = inpaint_metadata
# Step 3: Reconstruct layout if requested
fabric_json = None
layout_reconstructed = False
if reconstruct:
layout = reconstruct_layout(cleaned_bytes)
fabric_json = layout.fabric_json
layout_reconstructed = len(layout.elements) > 0
metadata["layout"] = {
"element_count": len(layout.elements),
"table_count": len(layout.table_regions),
"page_width": layout.page_width,
"page_height": layout.page_height
}
# Encode cleaned image as base64
cleaned_base64 = base64.b64encode(cleaned_bytes).decode('utf-8')
logger.info(f"Pipeline complete: detected={handwriting_detected}, "
f"removed={handwriting_removed}, layout={layout_reconstructed}")
return PipelineResponse(
success=True,
handwriting_detected=handwriting_detected,
handwriting_removed=handwriting_removed,
layout_reconstructed=layout_reconstructed,
cleaned_image_base64=cleaned_base64,
fabric_json=fabric_json,
metadata=metadata
)
except Exception as e:
logger.error(f"Cleanup pipeline failed: {e}")
import traceback
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.post("/preview-cleanup")
async def preview_cleanup(
image: UploadFile = File(...)
) -> JSONResponse:
"""
Quick preview of cleanup results without full processing.
Returns detection results and estimated processing time.
"""
content_type = image.content_type or ""
if not content_type.startswith("image/"):
raise HTTPException(
status_code=400,
detail="Only image files are supported"
)
try:
image_bytes = await image.read()
# Quick detection only
result = detect_handwriting_regions(image_bytes)
# Estimate processing time based on image size
from PIL import Image
img = Image.open(io.BytesIO(image_bytes))
pixel_count = img.width * img.height
# Rough estimates
est_detection_ms = 100 + (pixel_count / 1000000) * 200
est_inpainting_ms = 500 + (pixel_count / 1000000) * 1000
est_reconstruction_ms = 200 + (pixel_count / 1000000) * 300
return JSONResponse({
"has_handwriting": result["has_handwriting"],
"confidence": result["confidence"],
"handwriting_ratio": result["handwriting_ratio"],
"image_width": img.width,
"image_height": img.height,
"estimated_times_ms": {
"detection": est_detection_ms,
"inpainting": est_inpainting_ms if result["has_handwriting"] else 0,
"reconstruction": est_reconstruction_ms,
"total": est_detection_ms + (est_inpainting_ms if result["has_handwriting"] else 0) + est_reconstruction_ms
},
"capabilities": {
"lama_available": check_lama_available()
}
})
except Exception as e:
logger.error(f"Preview failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
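A minimal client sketch for the recommended cleanup-pipeline endpoint (base URL assumed; the form keys match the endpoint signature, with remove_handwriting being the aliased form field):

import base64
import httpx

def cleanup_worksheet(path: str, base_url: str = "http://localhost:8000") -> dict:
    with open(path, "rb") as fh:
        resp = httpx.post(
            f"{base_url}/api/v1/worksheet/cleanup-pipeline",
            files={"image": (path, fh, "image/png")},
            data={"remove_handwriting": "true", "reconstruct": "true", "inpainting_method": "auto"},
            timeout=120.0,
        )
    resp.raise_for_status()
    result = resp.json()
    # The cleaned page is returned base64-encoded; decode it back to PNG bytes if needed.
    if result.get("cleaned_image_base64"):
        with open("cleaned.png", "wb") as out:
            out.write(base64.b64decode(result["cleaned_image_base64"]))
    return result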
@@ -0,0 +1,485 @@
"""
Worksheet Editor AI — AI image generation and AI worksheet modification.
"""
import io
import json
import base64
import logging
import re
import time
import random
from typing import List, Dict
import httpx
from .editor_models import (
AIImageRequest,
AIImageResponse,
AIImageStyle,
AIModifyRequest,
AIModifyResponse,
OLLAMA_URL,
STYLE_PROMPTS,
)
logger = logging.getLogger(__name__)
# =============================================
# AI IMAGE GENERATION
# =============================================
async def generate_ai_image_logic(request: AIImageRequest) -> AIImageResponse:
"""
Generate an AI image using Ollama with a text-to-image model.
Falls back to a placeholder if Ollama is not available.
"""
from fastapi import HTTPException
try:
# Build enhanced prompt with style
style_modifier = STYLE_PROMPTS.get(request.style, "")
enhanced_prompt = f"{request.prompt}, {style_modifier}"
logger.info(f"Generating AI image: {enhanced_prompt[:100]}...")
# Check if Ollama is available
async with httpx.AsyncClient(timeout=10.0) as check_client:
try:
health_response = await check_client.get(f"{OLLAMA_URL}/api/tags")
if health_response.status_code != 200:
raise HTTPException(status_code=503, detail="Ollama service not available")
except httpx.ConnectError:
logger.warning("Ollama not reachable, returning placeholder")
return _generate_placeholder_image(request, enhanced_prompt)
try:
async with httpx.AsyncClient(timeout=300.0) as client:
tags_response = await client.get(f"{OLLAMA_URL}/api/tags")
available_models = [m.get("name", "") for m in tags_response.json().get("models", [])]
sd_model = None
for model in available_models:
if "stable" in model.lower() or "sd" in model.lower() or "diffusion" in model.lower():
sd_model = model
break
if not sd_model:
logger.warning("No Stable Diffusion model found in Ollama")
return _generate_placeholder_image(request, enhanced_prompt)
logger.info(f"SD model found: {sd_model}, but image generation API not implemented")
return _generate_placeholder_image(request, enhanced_prompt)
except Exception as e:
logger.error(f"Image generation failed: {e}")
return _generate_placeholder_image(request, enhanced_prompt)
except HTTPException:
raise
except Exception as e:
logger.error(f"AI image generation error: {e}")
raise HTTPException(status_code=500, detail=str(e))
def _generate_placeholder_image(request: AIImageRequest, prompt: str) -> AIImageResponse:
"""
Generate a placeholder image when AI generation is not available.
Creates a simple SVG-based placeholder with the prompt text.
"""
from PIL import Image, ImageDraw, ImageFont
width, height = request.width, request.height
style_colors = {
AIImageStyle.REALISTIC: ("#2563eb", "#dbeafe"),
AIImageStyle.CARTOON: ("#f97316", "#ffedd5"),
AIImageStyle.SKETCH: ("#6b7280", "#f3f4f6"),
AIImageStyle.CLIPART: ("#8b5cf6", "#ede9fe"),
AIImageStyle.EDUCATIONAL: ("#059669", "#d1fae5"),
}
fg_color, bg_color = style_colors.get(request.style, ("#6366f1", "#e0e7ff"))
img = Image.new('RGB', (width, height), bg_color)
draw = ImageDraw.Draw(img)
draw.rectangle([5, 5, width-6, height-6], outline=fg_color, width=3)
cx, cy = width // 2, height // 2 - 30
draw.ellipse([cx-40, cy-40, cx+40, cy+40], outline=fg_color, width=3)
draw.line([cx-20, cy-10, cx+20, cy-10], fill=fg_color, width=3)
draw.line([cx, cy-10, cx, cy+20], fill=fg_color, width=3)
try:
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
except Exception:
font = ImageFont.load_default()
max_chars = 40
lines = []
words = prompt[:200].split()
current_line = ""
for word in words:
if len(current_line) + len(word) + 1 <= max_chars:
current_line += (" " + word if current_line else word)
else:
if current_line:
lines.append(current_line)
current_line = word
if current_line:
lines.append(current_line)
text_y = cy + 60
for line in lines[:4]:
bbox = draw.textbbox((0, 0), line, font=font)
text_width = bbox[2] - bbox[0]
draw.text((cx - text_width // 2, text_y), line, fill=fg_color, font=font)
text_y += 20
badge_text = "KI-Bild (Platzhalter)"
try:
badge_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 10)
except Exception:
badge_font = font
draw.rectangle([10, height-30, 150, height-10], fill=fg_color)
draw.text((15, height-27), badge_text, fill="white", font=badge_font)
buffer = io.BytesIO()
img.save(buffer, format='PNG')
buffer.seek(0)
image_base64 = f"data:image/png;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
return AIImageResponse(
image_base64=image_base64,
prompt_used=prompt,
error="AI image generation not available. Using placeholder."
)
# =============================================
# AI WORKSHEET MODIFICATION
# =============================================
async def modify_worksheet_with_ai_logic(request: AIModifyRequest) -> AIModifyResponse:
"""
Modify a worksheet using AI based on natural language prompt.
"""
try:
logger.info(f"AI modify request: {request.prompt[:100]}...")
try:
canvas_data = json.loads(request.canvas_json)
except json.JSONDecodeError:
return AIModifyResponse(
message="Fehler beim Parsen des Canvas",
error="Invalid canvas JSON"
)
system_prompt = """Du bist ein Assistent fuer die Bearbeitung von Arbeitsblaettern.
Du erhaeltst den aktuellen Zustand eines Canvas im JSON-Format und eine Anweisung des Nutzers.
Deine Aufgabe ist es, die gewuenschten Aenderungen am Canvas vorzunehmen.
Der Canvas verwendet Fabric.js. Hier sind die wichtigsten Objekttypen:
- i-text: Interaktiver Text mit fontFamily, fontSize, fill, left, top
- rect: Rechteck mit left, top, width, height, fill, stroke, strokeWidth
- circle: Kreis mit left, top, radius, fill, stroke, strokeWidth
- line: Linie mit x1, y1, x2, y2, stroke, strokeWidth
Das Canvas ist 794x1123 Pixel (A4 bei 96 DPI).
Antworte NUR mit einem JSON-Objekt in diesem Format:
{
"action": "modify" oder "add" oder "delete" oder "info",
"objects": [...], // Neue/modifizierte Objekte (bei modify/add)
"message": "Kurze Beschreibung der Aenderung"
}
Wenn du Objekte hinzufuegst, generiere eindeutige IDs im Format "obj_<timestamp>_<random>".
"""
user_prompt = f"""Aktueller Canvas-Zustand:
```json
{json.dumps(canvas_data, indent=2)[:5000]}
```
Nutzer-Anweisung: {request.prompt}
Fuehre die Aenderung durch und antworte mit dem JSON-Objekt."""
try:
async with httpx.AsyncClient(timeout=120.0) as client:
response = await client.post(
f"{OLLAMA_URL}/api/generate",
json={
"model": request.model,
"prompt": user_prompt,
"system": system_prompt,
"stream": False,
"options": {
"temperature": 0.3,
"num_predict": 4096
}
}
)
if response.status_code != 200:
logger.warning(f"Ollama error: {response.status_code}, trying local fallback")
return _handle_simple_modification(request.prompt, canvas_data)
ai_response = response.json().get("response", "")
except httpx.ConnectError:
logger.warning("Ollama not reachable")
return _handle_simple_modification(request.prompt, canvas_data)
except httpx.TimeoutException:
logger.warning("Ollama timeout, trying local fallback")
return _handle_simple_modification(request.prompt, canvas_data)
try:
json_start = ai_response.find('{')
json_end = ai_response.rfind('}') + 1
if json_start == -1 or json_end <= json_start:
logger.warning(f"No JSON found in AI response: {ai_response[:200]}")
return AIModifyResponse(
message="KI konnte die Anfrage nicht verarbeiten",
error="No JSON in response"
)
ai_json = json.loads(ai_response[json_start:json_end])
action = ai_json.get("action", "info")
message = ai_json.get("message", "Aenderungen angewendet")
new_objects = ai_json.get("objects", [])
if action == "info":
return AIModifyResponse(message=message)
if action == "add" and new_objects:
existing_objects = canvas_data.get("objects", [])
existing_objects.extend(new_objects)
canvas_data["objects"] = existing_objects
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message=message
)
if action == "modify" and new_objects:
existing_objects = canvas_data.get("objects", [])
new_ids = {obj.get("id") for obj in new_objects if obj.get("id")}
kept_objects = [obj for obj in existing_objects if obj.get("id") not in new_ids]
kept_objects.extend(new_objects)
canvas_data["objects"] = kept_objects
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message=message
)
if action == "delete":
delete_ids = ai_json.get("delete_ids", [])
if delete_ids:
existing_objects = canvas_data.get("objects", [])
canvas_data["objects"] = [obj for obj in existing_objects if obj.get("id") not in delete_ids]
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message=message
)
return AIModifyResponse(message=message)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse AI JSON: {e}")
return AIModifyResponse(
message="Fehler beim Verarbeiten der KI-Antwort",
error=str(e)
)
except Exception as e:
logger.error(f"AI modify error: {e}")
return AIModifyResponse(
message="Ein unerwarteter Fehler ist aufgetreten",
error=str(e)
)
def _handle_simple_modification(prompt: str, canvas_data: dict) -> AIModifyResponse:
"""
Handle simple modifications locally when Ollama is not available.
Supports basic commands like adding headings, lines, etc.
"""
prompt_lower = prompt.lower()
objects = canvas_data.get("objects", [])
def generate_id():
return f"obj_{int(time.time()*1000)}_{random.randint(1000, 9999)}"
# Add heading
if "ueberschrift" in prompt_lower or "titel" in prompt_lower or "heading" in prompt_lower:
text_match = re.search(r'"([^"]+)"', prompt)
text = text_match.group(1) if text_match else "Ueberschrift"
new_text = {
"type": "i-text", "id": generate_id(), "text": text,
"left": 397, "top": 50, "originX": "center",
"fontFamily": "Arial", "fontSize": 28, "fontWeight": "bold", "fill": "#000000"
}
objects.append(new_text)
canvas_data["objects"] = objects
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message=f"Ueberschrift '{text}' hinzugefuegt"
)
# Add lines for writing
if "linie" in prompt_lower or "line" in prompt_lower or "schreib" in prompt_lower:
num_match = re.search(r'(\d+)', prompt)
num_lines = int(num_match.group(1)) if num_match else 5
num_lines = min(num_lines, 20)
start_y = 150
line_spacing = 40
for i in range(num_lines):
new_line = {
"type": "line", "id": generate_id(),
"x1": 60, "y1": start_y + i * line_spacing,
"x2": 734, "y2": start_y + i * line_spacing,
"stroke": "#cccccc", "strokeWidth": 1
}
objects.append(new_line)
canvas_data["objects"] = objects
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message=f"{num_lines} Schreiblinien hinzugefuegt"
)
# Make text bigger
if "groesser" in prompt_lower or "bigger" in prompt_lower or "larger" in prompt_lower:
modified = 0
for obj in objects:
if obj.get("type") in ["i-text", "text", "textbox"]:
current_size = obj.get("fontSize", 16)
obj["fontSize"] = int(current_size * 1.25)
modified += 1
canvas_data["objects"] = objects
if modified > 0:
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message=f"{modified} Texte vergroessert"
)
# Center elements
if "zentrier" in prompt_lower or "center" in prompt_lower or "mitte" in prompt_lower:
center_x = 397
for obj in objects:
if not obj.get("isGrid"):
obj["left"] = center_x
obj["originX"] = "center"
canvas_data["objects"] = objects
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message="Elemente zentriert"
)
# Add numbering
if "nummer" in prompt_lower or "nummerier" in prompt_lower or "1-10" in prompt_lower:
range_match = re.search(r'(\d+)\s*[-bis]+\s*(\d+)', prompt)
if range_match:
start, end = int(range_match.group(1)), int(range_match.group(2))
else:
start, end = 1, 10
y = 100
for i in range(start, min(end + 1, start + 20)):
new_text = {
"type": "i-text", "id": generate_id(), "text": f"{i}.",
"left": 40, "top": y, "fontFamily": "Arial", "fontSize": 14, "fill": "#000000"
}
objects.append(new_text)
y += 35
canvas_data["objects"] = objects
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message=f"Nummerierung {start}-{end} hinzugefuegt"
)
# Add rectangle/box
if "rechteck" in prompt_lower or "box" in prompt_lower or "kasten" in prompt_lower:
new_rect = {
"type": "rect", "id": generate_id(),
"left": 100, "top": 200, "width": 200, "height": 100,
"fill": "transparent", "stroke": "#000000", "strokeWidth": 2
}
objects.append(new_rect)
canvas_data["objects"] = objects
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message="Rechteck hinzugefuegt"
)
# Add grid/raster
if "raster" in prompt_lower or "grid" in prompt_lower or "tabelle" in prompt_lower:
dim_match = re.search(r'(\d+)\s*[x/\u00d7\*mal by]\s*(\d+)', prompt_lower)
if dim_match:
cols = int(dim_match.group(1))
rows = int(dim_match.group(2))
else:
nums = re.findall(r'(\d+)', prompt)
if len(nums) >= 2:
cols, rows = int(nums[0]), int(nums[1])
else:
cols, rows = 3, 4
cols = min(max(1, cols), 10)
rows = min(max(1, rows), 15)
canvas_width = 794
canvas_height = 1123
margin = 60
available_width = canvas_width - 2 * margin
available_height = canvas_height - 2 * margin - 80
cell_width = available_width / cols
cell_height = min(available_height / rows, 80)
start_x = margin
start_y = 120
grid_objects = []
for r in range(rows + 1):
y = start_y + r * cell_height
grid_objects.append({
"type": "line", "id": generate_id(),
"x1": start_x, "y1": y,
"x2": start_x + cols * cell_width, "y2": y,
"stroke": "#666666", "strokeWidth": 1, "isGrid": True
})
for c in range(cols + 1):
x = start_x + c * cell_width
grid_objects.append({
"type": "line", "id": generate_id(),
"x1": x, "y1": start_y,
"x2": x, "y2": start_y + rows * cell_height,
"stroke": "#666666", "strokeWidth": 1, "isGrid": True
})
objects.extend(grid_objects)
canvas_data["objects"] = objects
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message=f"{cols}x{rows} Raster hinzugefuegt ({cols} Spalten, {rows} Zeilen)"
)
# Default: Ollama needed
return AIModifyResponse(
message="Diese Aenderung erfordert den KI-Service. Bitte stellen Sie sicher, dass Ollama laeuft.",
error="Complex modification requires Ollama"
)
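For reference, a small sketch of the local fallback path above when Ollama is unreachable, exercising one of the keyword commands it understands (the package/module path is assumed from this diff's layout):

import json
from worksheet.editor_ai import _handle_simple_modification  # path assumed

canvas = {"version": "6.0.0", "objects": []}
resp = _handle_simple_modification('Fuege eine Ueberschrift "Vokabeltest" hinzu', canvas)

print(resp.message)                                            # Ueberschrift 'Vokabeltest' hinzugefuegt
print(len(json.loads(resp.modified_canvas_json)["objects"]))   # 1 centered, bold i-text object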
@@ -0,0 +1,388 @@
"""
Worksheet Editor API - Backend Endpoints for Visual Worksheet Editor
Provides endpoints for:
- AI Image generation via Ollama/Stable Diffusion
- Worksheet Save/Load
- PDF Export
Split modules:
- editor_models: Enums, Pydantic models, configuration
- editor_ai: AI image generation and AI worksheet modification
- editor_reconstruct: Document reconstruction from vocab sessions
"""
import os
import io
import json
import logging
from datetime import datetime, timezone
import uuid
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
import httpx
# Re-export everything from sub-modules for backward compatibility
from .editor_models import ( # noqa: F401
AIImageStyle,
WorksheetStatus,
AIImageRequest,
AIImageResponse,
PageData,
PageFormat,
WorksheetSaveRequest,
WorksheetResponse,
AIModifyRequest,
AIModifyResponse,
ReconstructRequest,
ReconstructResponse,
worksheets_db,
OLLAMA_URL,
SD_MODEL,
WORKSHEET_STORAGE_DIR,
STYLE_PROMPTS,
REPORTLAB_AVAILABLE,
)
from .editor_ai import ( # noqa: F401
generate_ai_image_logic,
_generate_placeholder_image,
modify_worksheet_with_ai_logic,
_handle_simple_modification,
)
from .editor_reconstruct import ( # noqa: F401
reconstruct_document_logic,
_detect_image_regions,
)
logger = logging.getLogger(__name__)
# =============================================
# ROUTER
# =============================================
router = APIRouter(prefix="/api/v1/worksheet", tags=["Worksheet Editor"])
# =============================================
# AI IMAGE GENERATION
# =============================================
@router.post("/ai-image", response_model=AIImageResponse)
async def generate_ai_image(request: AIImageRequest):
"""
Generate an AI image using Ollama with a text-to-image model.
Supported models:
- stable-diffusion (via Ollama)
- sd3.5-medium
- llava (for image understanding, not generation)
Falls back to a placeholder if Ollama is not available.
"""
return await generate_ai_image_logic(request)
# =============================================
# WORKSHEET SAVE/LOAD
# =============================================
@router.post("/save", response_model=WorksheetResponse)
async def save_worksheet(request: WorksheetSaveRequest):
"""
Save a worksheet document.
- If id is provided, updates existing worksheet
- If id is not provided, creates new worksheet
"""
try:
now = datetime.now(timezone.utc).isoformat()
worksheet_id = request.id or f"ws_{uuid.uuid4().hex[:12]}"
worksheet = {
"id": worksheet_id,
"title": request.title,
"description": request.description,
"pages": [p.dict() for p in request.pages],
"pageFormat": (request.pageFormat or PageFormat()).dict(),
"createdAt": worksheets_db.get(worksheet_id, {}).get("createdAt", now),
"updatedAt": now
}
worksheets_db[worksheet_id] = worksheet
filepath = os.path.join(WORKSHEET_STORAGE_DIR, f"{worksheet_id}.json")
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(worksheet, f, ensure_ascii=False, indent=2)
logger.info(f"Saved worksheet: {worksheet_id}")
return WorksheetResponse(**worksheet)
except Exception as e:
logger.error(f"Failed to save worksheet: {e}")
raise HTTPException(status_code=500, detail=f"Failed to save: {str(e)}")
@router.get("/{worksheet_id}", response_model=WorksheetResponse)
async def get_worksheet(worksheet_id: str):
"""Load a worksheet document by ID."""
try:
if worksheet_id in worksheets_db:
return WorksheetResponse(**worksheets_db[worksheet_id])
filepath = os.path.join(WORKSHEET_STORAGE_DIR, f"{worksheet_id}.json")
if os.path.exists(filepath):
with open(filepath, 'r', encoding='utf-8') as f:
worksheet = json.load(f)
worksheets_db[worksheet_id] = worksheet
return WorksheetResponse(**worksheet)
raise HTTPException(status_code=404, detail="Worksheet not found")
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to load worksheet {worksheet_id}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to load: {str(e)}")
@router.get("/list/all")
async def list_worksheets():
"""List all available worksheets."""
try:
worksheets = []
for filename in os.listdir(WORKSHEET_STORAGE_DIR):
if filename.endswith('.json'):
filepath = os.path.join(WORKSHEET_STORAGE_DIR, filename)
try:
with open(filepath, 'r', encoding='utf-8') as f:
worksheet = json.load(f)
worksheets.append({
"id": worksheet.get("id"),
"title": worksheet.get("title"),
"description": worksheet.get("description"),
"pageCount": len(worksheet.get("pages", [])),
"updatedAt": worksheet.get("updatedAt"),
"createdAt": worksheet.get("createdAt")
})
except Exception as e:
logger.warning(f"Failed to load {filename}: {e}")
worksheets.sort(key=lambda x: x.get("updatedAt", ""), reverse=True)
return {"worksheets": worksheets, "total": len(worksheets)}
except Exception as e:
logger.error(f"Failed to list worksheets: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/{worksheet_id}")
async def delete_worksheet(worksheet_id: str):
"""Delete a worksheet document."""
try:
if worksheet_id in worksheets_db:
del worksheets_db[worksheet_id]
filepath = os.path.join(WORKSHEET_STORAGE_DIR, f"{worksheet_id}.json")
if os.path.exists(filepath):
os.remove(filepath)
logger.info(f"Deleted worksheet: {worksheet_id}")
return {"status": "deleted", "id": worksheet_id}
raise HTTPException(status_code=404, detail="Worksheet not found")
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to delete worksheet {worksheet_id}: {e}")
raise HTTPException(status_code=500, detail=str(e))
# =============================================
# PDF EXPORT
# =============================================
@router.post("/{worksheet_id}/export-pdf")
async def export_worksheet_pdf(worksheet_id: str):
"""
Export worksheet as PDF.
Note: This creates a basic PDF. For full canvas rendering,
the frontend should use pdf-lib with canvas.toDataURL().
"""
if not REPORTLAB_AVAILABLE:
raise HTTPException(status_code=501, detail="PDF export not available (reportlab not installed)")
try:
from reportlab.lib.pagesizes import A4
from reportlab.pdfgen import canvas
worksheet = worksheets_db.get(worksheet_id)
if not worksheet:
filepath = os.path.join(WORKSHEET_STORAGE_DIR, f"{worksheet_id}.json")
if os.path.exists(filepath):
with open(filepath, 'r', encoding='utf-8') as f:
worksheet = json.load(f)
else:
raise HTTPException(status_code=404, detail="Worksheet not found")
buffer = io.BytesIO()
c = canvas.Canvas(buffer, pagesize=A4)
page_width, page_height = A4
for page_data in worksheet.get("pages", []):
if page_data.get("index", 0) == 0:
c.setFont("Helvetica-Bold", 18)
c.drawString(50, page_height - 50, worksheet.get("title", "Arbeitsblatt"))
c.setFont("Helvetica", 10)
c.drawString(50, page_height - 70, f"Erstellt: {worksheet.get('createdAt', '')[:10]}")
canvas_json_str = page_data.get("canvasJSON", "{}")
if canvas_json_str:
try:
canvas_data = json.loads(canvas_json_str)
objects = canvas_data.get("objects", [])
for obj in objects:
obj_type = obj.get("type", "")
if obj_type in ["text", "i-text", "textbox"]:
text = obj.get("text", "")
left = obj.get("left", 50)
top = obj.get("top", 100)
font_size = obj.get("fontSize", 12)
pdf_x = left * 0.75
pdf_y = page_height - (top * 0.75)
c.setFont("Helvetica", min(font_size, 24))
c.drawString(pdf_x, pdf_y, text[:100])
elif obj_type == "rect":
left = obj.get("left", 0) * 0.75
top = obj.get("top", 0) * 0.75
width = obj.get("width", 50) * 0.75
height = obj.get("height", 30) * 0.75
c.rect(left, page_height - top - height, width, height)
elif obj_type == "circle":
left = obj.get("left", 0) * 0.75
top = obj.get("top", 0) * 0.75
radius = obj.get("radius", 25) * 0.75
c.circle(left + radius, page_height - top - radius, radius)
except json.JSONDecodeError:
pass
c.showPage()
c.save()
buffer.seek(0)
filename = f"{worksheet.get('title', 'worksheet').replace(' ', '_')}.pdf"
return StreamingResponse(
buffer,
media_type="application/pdf",
headers={"Content-Disposition": f"attachment; filename={filename}"}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"PDF export failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
# =============================================
# AI WORKSHEET MODIFICATION
# =============================================
@router.post("/ai-modify", response_model=AIModifyResponse)
async def modify_worksheet_with_ai(request: AIModifyRequest):
"""
Modify a worksheet using AI based on natural language prompt.
Uses Ollama with qwen2.5vl:32b to understand the canvas state
and generate modifications based on the user's request.
"""
return await modify_worksheet_with_ai_logic(request)
# =============================================
# HEALTH CHECK
# =============================================
@router.get("/health/check")
async def health_check():
"""Check worksheet editor API health and dependencies."""
status = {
"status": "healthy",
"ollama": False,
"storage": os.path.exists(WORKSHEET_STORAGE_DIR),
"reportlab": REPORTLAB_AVAILABLE,
"worksheets_count": len(worksheets_db)
}
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(f"{OLLAMA_URL}/api/tags")
status["ollama"] = response.status_code == 200
except Exception:
pass
return status
# =============================================
# DOCUMENT RECONSTRUCTION FROM VOCAB SESSION
# =============================================
@router.post("/reconstruct-from-session", response_model=ReconstructResponse)
async def reconstruct_document_from_session(request: ReconstructRequest):
"""
Reconstruct a document from a vocab session into Fabric.js canvas format.
Returns canvas JSON ready to load into the worksheet editor.
"""
try:
return await reconstruct_document_logic(request)
except HTTPException:
raise
except Exception as e:
logger.error(f"Document reconstruction failed: {e}")
import traceback
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.get("/sessions/available")
async def get_available_sessions():
"""Get list of available vocab sessions that can be reconstructed."""
try:
from vocab_worksheet_api import _sessions
available = []
for session_id, session in _sessions.items():
if session.get("pdf_data"):
available.append({
"id": session_id,
"name": session.get("name", "Unnamed"),
"description": session.get("description"),
"vocabulary_count": len(session.get("vocabulary", [])),
"page_count": session.get("pdf_page_count", 1),
"status": session.get("status", "unknown"),
"created_at": session.get("created_at", "").isoformat() if session.get("created_at") else None
})
return {"sessions": available, "total": len(available)}
except Exception as e:
logger.error(f"Failed to list sessions: {e}")
raise HTTPException(status_code=500, detail=str(e))
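A minimal save/load round-trip against the editor endpoints above (base URL assumed; the payload shape follows WorksheetSaveRequest and PageData from editor_models):

import httpx

BASE = "http://localhost:8000"  # assumption

payload = {
    "title": "Testblatt",
    "pages": [{"id": "page_1", "index": 0, "canvasJSON": '{"objects": []}'}],
}
created = httpx.post(f"{BASE}/api/v1/worksheet/save", json=payload).json()

# Fetch it back by ID; createdAt/updatedAt are filled in server-side.
loaded = httpx.get(f"{BASE}/api/v1/worksheet/{created['id']}").json()
assert loaded["title"] == "Testblatt"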
@@ -0,0 +1,133 @@
"""
Worksheet Editor Models — Enums, Pydantic models, and configuration.
"""
import os
import logging
from typing import Optional, List, Dict
from enum import Enum
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
# =============================================
# CONFIGURATION
# =============================================
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434")
SD_MODEL = os.getenv("SD_MODEL", "stable-diffusion") # or specific SD model
WORKSHEET_STORAGE_DIR = os.getenv("WORKSHEET_STORAGE_DIR",
os.path.join(os.path.dirname(os.path.abspath(__file__)), "worksheet-storage"))
# Ensure storage directory exists
os.makedirs(WORKSHEET_STORAGE_DIR, exist_ok=True)
# =============================================
# ENUMS & MODELS
# =============================================
class AIImageStyle(str, Enum):
REALISTIC = "realistic"
CARTOON = "cartoon"
SKETCH = "sketch"
CLIPART = "clipart"
EDUCATIONAL = "educational"
class WorksheetStatus(str, Enum):
DRAFT = "draft"
PUBLISHED = "published"
ARCHIVED = "archived"
# Style prompt modifiers
STYLE_PROMPTS = {
AIImageStyle.REALISTIC: "photorealistic, high detail, professional photography",
AIImageStyle.CARTOON: "cartoon style, colorful, child-friendly, simple shapes",
AIImageStyle.SKETCH: "pencil sketch, hand-drawn, black and white, artistic",
AIImageStyle.CLIPART: "clipart style, flat design, simple, vector-like",
AIImageStyle.EDUCATIONAL: "educational illustration, clear, informative, textbook style"
}
# =============================================
# REQUEST/RESPONSE MODELS
# =============================================
class AIImageRequest(BaseModel):
prompt: str = Field(..., min_length=3, max_length=500)
style: AIImageStyle = AIImageStyle.EDUCATIONAL
width: int = Field(512, ge=256, le=1024)
height: int = Field(512, ge=256, le=1024)
class AIImageResponse(BaseModel):
image_base64: str
prompt_used: str
error: Optional[str] = None
class PageData(BaseModel):
id: str
index: int
canvasJSON: str
class PageFormat(BaseModel):
width: float = 210
height: float = 297
orientation: str = "portrait"
margins: Dict[str, float] = {"top": 15, "right": 15, "bottom": 15, "left": 15}
class WorksheetSaveRequest(BaseModel):
id: Optional[str] = None
title: str
description: Optional[str] = None
pages: List[PageData]
pageFormat: Optional[PageFormat] = None
class WorksheetResponse(BaseModel):
id: str
title: str
description: Optional[str]
pages: List[PageData]
pageFormat: PageFormat
createdAt: str
updatedAt: str
class AIModifyRequest(BaseModel):
prompt: str = Field(..., min_length=3, max_length=1000)
canvas_json: str
model: str = "qwen2.5vl:32b"
class AIModifyResponse(BaseModel):
modified_canvas_json: Optional[str] = None
message: str
error: Optional[str] = None
class ReconstructRequest(BaseModel):
session_id: str
page_number: int = 1
include_images: bool = True
regenerate_graphics: bool = False
class ReconstructResponse(BaseModel):
canvas_json: str
page_width: int
page_height: int
elements_count: int
vocabulary_matched: int
message: str
error: Optional[str] = None
# =============================================
# IN-MEMORY STORAGE (Development)
# =============================================
worksheets_db: Dict[str, Dict] = {}
# PDF Generation availability
try:
from reportlab.lib import colors # noqa: F401
from reportlab.lib.pagesizes import A4 # noqa: F401
from reportlab.lib.units import mm # noqa: F401
from reportlab.pdfgen import canvas # noqa: F401
from reportlab.lib.styles import getSampleStyleSheet # noqa: F401
REPORTLAB_AVAILABLE = True
except ImportError:
REPORTLAB_AVAILABLE = False
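A small sketch of the configuration defaults above; environment overrides have to be set before the module is imported, and the package path is assumed from this diff's layout:

import os

os.environ["OLLAMA_URL"] = "http://127.0.0.1:11434"      # example override
os.environ["WORKSHEET_STORAGE_DIR"] = "/tmp/worksheets"   # example override

from worksheet.editor_models import OLLAMA_URL, AIImageRequest  # path assumed

req = AIImageRequest(prompt="Ein Apfelbaum im Herbst")  # style defaults to EDUCATIONAL, 512x512
print(OLLAMA_URL, req.style.value, req.width, req.height)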
@@ -0,0 +1,255 @@
"""
Worksheet Editor Reconstruct — Document reconstruction from vocab sessions.
"""
import io
import uuid
import base64
import logging
from typing import List, Dict
import numpy as np
from .editor_models import (
ReconstructRequest,
ReconstructResponse,
)
logger = logging.getLogger(__name__)
async def reconstruct_document_logic(request: ReconstructRequest) -> ReconstructResponse:
"""
Reconstruct a document from a vocab session into Fabric.js canvas format.
This function:
1. Loads the original PDF from the vocab session
2. Runs OCR with position tracking
3. Creates Fabric.js canvas JSON with positioned elements
4. Maps extracted vocabulary to their positions
Returns ReconstructResponse ready to send to the client.
"""
from fastapi import HTTPException
from vocab_worksheet_api import _sessions, convert_pdf_page_to_image
# Check if session exists
if request.session_id not in _sessions:
raise HTTPException(status_code=404, detail=f"Session {request.session_id} not found")
session = _sessions[request.session_id]
if not session.get("pdf_data"):
raise HTTPException(status_code=400, detail="Session has no PDF data")
pdf_data = session["pdf_data"]
page_count = session.get("pdf_page_count", 1)
if request.page_number < 1 or request.page_number > page_count:
raise HTTPException(
status_code=400,
detail=f"Page {request.page_number} not found. PDF has {page_count} pages."
)
vocabulary = session.get("vocabulary", [])
page_vocab = [v for v in vocabulary if v.get("source_page") == request.page_number]
logger.info(f"Reconstructing page {request.page_number} from session {request.session_id}")
logger.info(f"Found {len(page_vocab)} vocabulary items for this page")
image_bytes = await convert_pdf_page_to_image(pdf_data, request.page_number)
if not image_bytes:
raise HTTPException(status_code=500, detail="Failed to convert PDF page to image")
from PIL import Image
img = Image.open(io.BytesIO(image_bytes))
img_width, img_height = img.size
from hybrid_vocab_extractor import run_paddle_ocr
ocr_regions, raw_text = run_paddle_ocr(image_bytes)
logger.info(f"OCR found {len(ocr_regions)} text regions")
A4_WIDTH = 794
A4_HEIGHT = 1123
scale_x = A4_WIDTH / img_width
scale_y = A4_HEIGHT / img_height
fabric_objects = []
# 1. Add white background
fabric_objects.append({
"type": "rect", "left": 0, "top": 0,
"width": A4_WIDTH, "height": A4_HEIGHT,
"fill": "#ffffff", "selectable": False,
"evented": False, "isBackground": True
})
# 2. Group OCR regions by Y-coordinate to detect rows
sorted_regions = sorted(ocr_regions, key=lambda r: (r.y1, r.x1))
# 3. Detect headers (larger text at top)
headers = []
for region in sorted_regions:
height = region.y2 - region.y1
if region.y1 < img_height * 0.15 and height > 30:
headers.append(region)
# 4. Create text objects for each region
vocab_matched = 0
for region in sorted_regions:
left = int(region.x1 * scale_x)
top = int(region.y1 * scale_y)
is_header = region in headers
region_height = region.y2 - region.y1
base_font_size = max(10, min(32, int(region_height * scale_y * 0.8)))
if is_header:
base_font_size = max(base_font_size, 24)
is_vocab = False
vocab_match = None
for v in page_vocab:
if v.get("english", "").lower() in region.text.lower() or \
v.get("german", "").lower() in region.text.lower():
is_vocab = True
vocab_match = v
vocab_matched += 1
break
text_obj = {
"type": "i-text",
"id": f"text_{uuid.uuid4().hex[:8]}",
"left": left, "top": top,
"text": region.text,
"fontFamily": "Arial",
"fontSize": base_font_size,
"fontWeight": "bold" if is_header else "normal",
"fill": "#000000",
"originX": "left", "originY": "top",
}
if is_vocab and vocab_match:
text_obj["isVocabulary"] = True
text_obj["vocabularyId"] = vocab_match.get("id")
text_obj["english"] = vocab_match.get("english")
text_obj["german"] = vocab_match.get("german")
fabric_objects.append(text_obj)
# 5. If include_images, detect and extract image regions
if request.include_images:
image_regions = await _detect_image_regions(image_bytes, ocr_regions, img_width, img_height)
for i, img_region in enumerate(image_regions):
img_x1 = int(img_region["x1"])
img_y1 = int(img_region["y1"])
img_x2 = int(img_region["x2"])
img_y2 = int(img_region["y2"])
cropped = img.crop((img_x1, img_y1, img_x2, img_y2))
buffer = io.BytesIO()
cropped.save(buffer, format='PNG')
buffer.seek(0)
img_base64 = f"data:image/png;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
fabric_objects.append({
"type": "image",
"id": f"img_{uuid.uuid4().hex[:8]}",
"left": int(img_x1 * scale_x),
"top": int(img_y1 * scale_y),
"width": int((img_x2 - img_x1) * scale_x),
"height": int((img_y2 - img_y1) * scale_y),
"src": img_base64,
"scaleX": 1, "scaleY": 1,
})
import json
canvas_data = {
"version": "6.0.0",
"objects": fabric_objects,
"background": "#ffffff"
}
return ReconstructResponse(
canvas_json=json.dumps(canvas_data),
page_width=A4_WIDTH,
page_height=A4_HEIGHT,
elements_count=len(fabric_objects),
vocabulary_matched=vocab_matched,
message=f"Reconstructed page {request.page_number} with {len(fabric_objects)} elements, "
f"{vocab_matched} vocabulary items matched"
)
async def _detect_image_regions(
image_bytes: bytes,
ocr_regions: list,
img_width: int,
img_height: int
) -> List[Dict]:
"""
Detect image/graphic regions in the document.
Uses a simple approach:
1. Mask out the known OCR text regions
2. Run Canny edge detection and find external contours in the remainder
3. Keep large, high-variance regions and drop overlapping boxes
"""
from PIL import Image
import cv2
try:
img = Image.open(io.BytesIO(image_bytes))
img_array = np.array(img.convert('L'))
text_mask = np.ones_like(img_array, dtype=bool)
for region in ocr_regions:
x1 = max(0, region.x1 - 5)
y1 = max(0, region.y1 - 5)
x2 = min(img_width, region.x2 + 5)
y2 = min(img_height, region.y2 + 5)
text_mask[y1:y2, x1:x2] = False
image_regions = []
edges = cv2.Canny(img_array, 50, 150)
edges[~text_mask] = 0
contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
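# Keep only contours that are reasonably large (> 50 px per side), smaller than
# ~90% of the page, and whose pixel variance exceeds 500 -- low-variance boxes
# are mostly blank paper rather than actual graphics.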
for contour in contours:
x, y, w, h = cv2.boundingRect(contour)
if w > 50 and h > 50:
if w < img_width * 0.9 and h < img_height * 0.9:
region_content = img_array[y:y+h, x:x+w]
variance = np.var(region_content)
if variance > 500:
image_regions.append({
"x1": x, "y1": y,
"x2": x + w, "y2": y + h
})
filtered_regions = []
for region in sorted(image_regions, key=lambda r: (r["x2"]-r["x1"])*(r["y2"]-r["y1"]), reverse=True):
overlaps = False
for existing in filtered_regions:
if not (region["x2"] < existing["x1"] or region["x1"] > existing["x2"] or
region["y2"] < existing["y1"] or region["y1"] > existing["y2"]):
overlaps = True
break
if not overlaps:
filtered_regions.append(region)
logger.info(f"Detected {len(filtered_regions)} image regions")
return filtered_regions[:10]
except Exception as e:
logger.warning(f"Image region detection failed: {e}")
return []
@@ -0,0 +1,26 @@
"""
NRU Worksheet Generator — barrel re-export.
The implementation is split into:
nru_models — data classes, entry separation
nru_html — HTML generation
nru_pdf — PDF generation
Per scanned page, we generate 2 worksheet pages.
"""
# Models
from .nru_models import ( # noqa: F401
VocabEntry,
SentenceEntry,
separate_vocab_and_sentences,
)
# HTML generation
from .nru_html import ( # noqa: F401
generate_nru_html,
generate_nru_worksheet_html,
)
# PDF generation
from .nru_pdf import generate_nru_pdf # noqa: F401
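# Usage sketch (illustrative only; the entry dicts, title, and import path are
# assumptions, not taken from a real session):
#
#   from nru_worksheet_generator import generate_nru_pdf  # legacy name via shim
#
#   entries = [
#       {"english": "the bakery", "german": "die Baeckerei", "source_page": 1},
#       {"english": "We went shopping yesterday.", "german": "Wir waren gestern einkaufen.", "source_page": 1},
#   ]
#   worksheet_pdf, solution_pdf = await generate_nru_pdf(entries, title="Unit 3")
#
# The first entry ends up in the vocabulary table, the second (sentence-like, ends
# with ".") on the sentence-practice page; see separate_vocab_and_sentences.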
@@ -0,0 +1,466 @@
"""
NRU Worksheet HTML — HTML generation for vocabulary worksheets.
Extracted from nru_worksheet_generator.py for modularity.
"""
import logging
from typing import List, Dict, Optional
from .nru_models import VocabEntry, SentenceEntry, separate_vocab_and_sentences
logger = logging.getLogger(__name__)
def generate_nru_html(
vocab_list: List[VocabEntry],
sentence_list: List[SentenceEntry],
page_number: int,
title: str = "Vokabeltest",
show_solutions: bool = False,
line_height_px: int = 28
) -> str:
"""
Generate HTML for NRU-format worksheet.
Returns HTML for 2 pages:
- Page 1: Vocabulary table (3 columns)
- Page 2: Sentence practice (full width)
"""
# Filter by page
page_vocab = [v for v in vocab_list if v.source_page == page_number]
page_sentences = [s for s in sentence_list if s.source_page == page_number]
html = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
@page {{
size: A4;
margin: 1.5cm 2cm;
}}
* {{
box-sizing: border-box;
}}
body {{
font-family: Arial, Helvetica, sans-serif;
font-size: 12pt;
line-height: 1.4;
margin: 0;
padding: 0;
}}
.page {{
page-break-after: always;
min-height: 100%;
}}
.page:last-child {{
page-break-after: avoid;
}}
h1 {{
font-size: 16pt;
margin: 0 0 8px 0;
text-align: center;
}}
.header {{
margin-bottom: 15px;
}}
.name-line {{
font-size: 11pt;
margin-bottom: 10px;
}}
/* Vocabulary Table - 3 columns */
.vocab-table {{
width: 100%;
border-collapse: collapse;
table-layout: fixed;
}}
.vocab-table th {{
background: #f0f0f0;
border: 1px solid #333;
padding: 6px 8px;
font-weight: bold;
font-size: 11pt;
text-align: left;
}}
.vocab-table td {{
border: 1px solid #333;
padding: 4px 8px;
height: {line_height_px}px;
vertical-align: middle;
}}
.vocab-table .col-english {{ width: 35%; }}
.vocab-table .col-german {{ width: 35%; }}
.vocab-table .col-correction {{ width: 30%; }}
.vocab-answer {{
color: #0066cc;
font-style: italic;
}}
/* Sentence Table - full width */
.sentence-table {{
width: 100%;
border-collapse: collapse;
margin-bottom: 15px;
}}
.sentence-table td {{
border: 1px solid #333;
padding: 6px 10px;
}}
.sentence-header {{
background: #f5f5f5;
font-weight: normal;
min-height: 30px;
}}
.sentence-line {{
height: {line_height_px + 4}px;
}}
.sentence-answer {{
color: #0066cc;
font-style: italic;
font-size: 11pt;
}}
.page-info {{
font-size: 9pt;
color: #666;
text-align: right;
margin-top: 10px;
}}
</style>
</head>
<body>
"""
# ========== PAGE 1: VOCABULARY TABLE ==========
if page_vocab:
html += f"""
<div class="page">
<div class="header">
<h1>{title} - Vokabeln (Seite {page_number})</h1>
<div class="name-line">Name: _________________________ Datum: _____________</div>
</div>
<table class="vocab-table">
<thead>
<tr>
<th class="col-english">Englisch</th>
<th class="col-german">Deutsch</th>
<th class="col-correction">Korrektur</th>
</tr>
</thead>
<tbody>
"""
for v in page_vocab:
if show_solutions:
html += f"""
<tr>
<td>{v.english}</td>
<td class="vocab-answer">{v.german}</td>
<td></td>
</tr>
"""
else:
html += f"""
<tr>
<td>{v.english}</td>
<td></td>
<td></td>
</tr>
"""
html += """
</tbody>
</table>
<div class="page-info">Vokabeln aus Unit</div>
</div>
"""
# ========== PAGE 2: SENTENCE PRACTICE ==========
if page_sentences:
html += f"""
<div class="page">
<div class="header">
<h1>{title} - Lernsaetze (Seite {page_number})</h1>
<div class="name-line">Name: _________________________ Datum: _____________</div>
</div>
"""
for s in page_sentences:
html += f"""
<table class="sentence-table">
<tr>
<td class="sentence-header">{s.german}</td>
</tr>
"""
if show_solutions:
html += f"""
<tr>
<td class="sentence-line sentence-answer">{s.english}</td>
</tr>
<tr>
<td class="sentence-line"></td>
</tr>
"""
else:
html += """
<tr>
<td class="sentence-line"></td>
</tr>
<tr>
<td class="sentence-line"></td>
</tr>
"""
html += """
</table>
"""
html += """
<div class="page-info">Lernsaetze aus Unit</div>
</div>
"""
html += """
</body>
</html>
"""
return html
def generate_nru_worksheet_html(
entries: List[Dict],
title: str = "Vokabeltest",
show_solutions: bool = False,
specific_pages: Optional[List[int]] = None
) -> str:
"""
Generate complete NRU worksheet HTML for all pages.
Args:
entries: List of vocabulary entries with source_page
title: Worksheet title
show_solutions: Whether to show answers
specific_pages: List of specific page numbers to include (1-indexed)
Returns:
Complete HTML document
"""
# Separate into vocab and sentences
vocab_list, sentence_list = separate_vocab_and_sentences(entries)
# Get unique page numbers
all_pages = set()
for v in vocab_list:
all_pages.add(v.source_page)
for s in sentence_list:
all_pages.add(s.source_page)
# Filter to specific pages if requested
if specific_pages:
all_pages = all_pages.intersection(set(specific_pages))
pages_sorted = sorted(all_pages)
logger.info(f"Generating NRU worksheet for pages {pages_sorted}")
logger.info(f"Total vocab: {len(vocab_list)}, Total sentences: {len(sentence_list)}")
# Generate HTML for each page
combined_html = """<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<style>
@page {
size: A4;
margin: 1.5cm 2cm;
}
* {
box-sizing: border-box;
}
body {
font-family: Arial, Helvetica, sans-serif;
font-size: 12pt;
line-height: 1.4;
margin: 0;
padding: 0;
}
.page {
page-break-after: always;
min-height: 100%;
}
.page:last-child {
page-break-after: avoid;
}
h1 {
font-size: 16pt;
margin: 0 0 8px 0;
text-align: center;
}
.header {
margin-bottom: 15px;
}
.name-line {
font-size: 11pt;
margin-bottom: 10px;
}
/* Vocabulary Table - 3 columns */
.vocab-table {
width: 100%;
border-collapse: collapse;
table-layout: fixed;
}
.vocab-table th {
background: #f0f0f0;
border: 1px solid #333;
padding: 6px 8px;
font-weight: bold;
font-size: 11pt;
text-align: left;
}
.vocab-table td {
border: 1px solid #333;
padding: 4px 8px;
height: 28px;
vertical-align: middle;
}
.vocab-table .col-english { width: 35%; }
.vocab-table .col-german { width: 35%; }
.vocab-table .col-correction { width: 30%; }
.vocab-answer {
color: #0066cc;
font-style: italic;
}
/* Sentence Table - full width */
.sentence-table {
width: 100%;
border-collapse: collapse;
margin-bottom: 15px;
}
.sentence-table td {
border: 1px solid #333;
padding: 6px 10px;
}
.sentence-header {
background: #f5f5f5;
font-weight: normal;
min-height: 30px;
}
.sentence-line {
height: 32px;
}
.sentence-answer {
color: #0066cc;
font-style: italic;
font-size: 11pt;
}
.page-info {
font-size: 9pt;
color: #666;
text-align: right;
margin-top: 10px;
}
</style>
</head>
<body>
"""
for page_num in pages_sorted:
page_vocab = [v for v in vocab_list if v.source_page == page_num]
page_sentences = [s for s in sentence_list if s.source_page == page_num]
# PAGE 1: VOCABULARY TABLE
if page_vocab:
combined_html += f"""
<div class="page">
<div class="header">
<h1>{title} - Vokabeln (Seite {page_num})</h1>
<div class="name-line">Name: _________________________ Datum: _____________</div>
</div>
<table class="vocab-table">
<thead>
<tr>
<th class="col-english">Englisch</th>
<th class="col-german">Deutsch</th>
<th class="col-correction">Korrektur</th>
</tr>
</thead>
<tbody>
"""
for v in page_vocab:
if show_solutions:
combined_html += f"""
<tr>
<td>{v.english}</td>
<td class="vocab-answer">{v.german}</td>
<td></td>
</tr>
"""
else:
combined_html += f"""
<tr>
<td>{v.english}</td>
<td></td>
<td></td>
</tr>
"""
combined_html += f"""
</tbody>
</table>
<div class="page-info">{title} - Seite {page_num}</div>
</div>
"""
# PAGE 2: SENTENCE PRACTICE
if page_sentences:
combined_html += f"""
<div class="page">
<div class="header">
<h1>{title} - Lernsaetze (Seite {page_num})</h1>
<div class="name-line">Name: _________________________ Datum: _____________</div>
</div>
"""
for s in page_sentences:
combined_html += f"""
<table class="sentence-table">
<tr>
<td class="sentence-header">{s.german}</td>
</tr>
"""
if show_solutions:
combined_html += f"""
<tr>
<td class="sentence-line sentence-answer">{s.english}</td>
</tr>
<tr>
<td class="sentence-line"></td>
</tr>
"""
else:
combined_html += """
<tr>
<td class="sentence-line"></td>
</tr>
<tr>
<td class="sentence-line"></td>
</tr>
"""
combined_html += """
</table>
"""
combined_html += f"""
<div class="page-info">{title} - Seite {page_num}</div>
</div>
"""
combined_html += """
</body>
</html>
"""
return combined_html
@@ -0,0 +1,70 @@
"""
NRU Worksheet Models — data classes and entry separation logic.
Extracted from nru_worksheet_generator.py for modularity.
"""
import logging
from typing import List, Dict, Tuple
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class VocabEntry:
english: str
german: str
source_page: int = 1
@dataclass
class SentenceEntry:
german: str
english: str # For solution sheet
source_page: int = 1
def separate_vocab_and_sentences(entries: List[Dict]) -> Tuple[List[VocabEntry], List[SentenceEntry]]:
"""
Separate vocabulary entries into single words/phrases and full sentences.
Sentences are identified by:
- Ending with punctuation (. ! ?)
- Being longer than 50 characters
- Having more than 5 words with a capitalized word after the first
"""
vocab_list = []
sentence_list = []
for entry in entries:
english = entry.get("english", "").strip()
german = entry.get("german", "").strip()
source_page = entry.get("source_page", 1)
if not english or not german:
continue
# Detect if this is a sentence
is_sentence = (
english.endswith('.') or
english.endswith('!') or
english.endswith('?') or
len(english) > 50 or
(len(english.split()) > 5 and any(w[0].isupper() for w in english.split()[1:] if w))
)
if is_sentence:
sentence_list.append(SentenceEntry(
german=german,
english=english,
source_page=source_page
))
else:
vocab_list.append(VocabEntry(
english=english,
german=german,
source_page=source_page
))
return vocab_list, sentence_list
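# Hedged illustration of the heuristic above (made-up entries):
#   separate_vocab_and_sentences([
#       {"english": "to improve", "german": "verbessern", "source_page": 2},
#       {"english": "I have never been to London.", "german": "Ich war noch nie in London.", "source_page": 2},
#   ])
# returns one VocabEntry ("to improve") and one SentenceEntry (the trailing "."
# marks the second item as a sentence).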
@@ -0,0 +1,31 @@
"""
NRU Worksheet PDF — PDF generation using weasyprint.
Extracted from nru_worksheet_generator.py for modularity.
"""
from typing import List, Dict, Optional, Tuple
from .nru_html import generate_nru_worksheet_html
async def generate_nru_pdf(entries: List[Dict], title: str = "Vokabeltest", include_solutions: bool = True) -> Tuple[bytes, Optional[bytes]]:
"""
Generate NRU worksheet PDFs.
Returns:
Tuple of (worksheet_pdf_bytes, solution_pdf_bytes);
solution_pdf_bytes is None when include_solutions is False.
"""
from weasyprint import HTML
# Generate worksheet HTML
worksheet_html = generate_nru_worksheet_html(entries, title, show_solutions=False)
worksheet_pdf = HTML(string=worksheet_html).write_pdf()
# Generate solution HTML
solution_pdf = None
if include_solutions:
solution_html = generate_nru_worksheet_html(entries, title, show_solutions=True)
solution_pdf = HTML(string=solution_html).write_pdf()
return worksheet_pdf, solution_pdf
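# Note: declared async for a uniform call signature, but WeasyPrint rendering is
# synchronous and CPU-bound. For large worksheets, callers may prefer to off-load
# the write_pdf() calls, e.g. via asyncio.get_running_loop().run_in_executor(...).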
@@ -1,491 +1,4 @@
"""
Worksheet Cleanup API - handwriting removal and layout reconstruction
Endpoints:
- POST /api/v1/worksheet/detect-handwriting - detects handwriting and returns a mask
- POST /api/v1/worksheet/remove-handwriting - removes handwriting from an image
- POST /api/v1/worksheet/reconstruct - reconstructs the layout as Fabric.js JSON
- POST /api/v1/worksheet/cleanup-pipeline - full pipeline (detection + removal + layout)
PRIVACY: All processing runs locally on the Mac Mini.
"""
import io
import base64
import logging
from typing import Optional
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from services.handwriting_detection import (
detect_handwriting,
detect_handwriting_regions,
mask_to_png
)
from services.inpainting_service import (
inpaint_image,
remove_handwriting,
InpaintingMethod,
check_lama_available
)
from services.layout_reconstruction_service import (
reconstruct_layout,
layout_to_fabric_json,
reconstruct_and_clean
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/worksheet", tags=["Worksheet Cleanup"])
# =============================================================================
# Pydantic Models
# =============================================================================
class DetectionResponse(BaseModel):
has_handwriting: bool
confidence: float
handwriting_ratio: float
detection_method: str
mask_base64: Optional[str] = None
class InpaintingResponse(BaseModel):
success: bool
method_used: str
processing_time_ms: float
image_base64: Optional[str] = None
error: Optional[str] = None
class ReconstructionResponse(BaseModel):
success: bool
element_count: int
page_width: int
page_height: int
fabric_json: dict
table_count: int = 0
class PipelineResponse(BaseModel):
success: bool
handwriting_detected: bool
handwriting_removed: bool
layout_reconstructed: bool
cleaned_image_base64: Optional[str] = None
fabric_json: Optional[dict] = None
metadata: dict = {}
class CapabilitiesResponse(BaseModel):
opencv_available: bool = True
lama_available: bool = False
paddleocr_available: bool = False
# =============================================================================
# API Endpoints
# =============================================================================
@router.get("/capabilities")
async def get_capabilities() -> CapabilitiesResponse:
"""
Get available cleanup capabilities on this server.
"""
# Check PaddleOCR
paddleocr_available = False
try:
from hybrid_vocab_extractor import get_paddle_ocr
ocr = get_paddle_ocr()
paddleocr_available = ocr is not None
except Exception:
pass
return CapabilitiesResponse(
opencv_available=True,
lama_available=check_lama_available(),
paddleocr_available=paddleocr_available
)
@router.post("/detect-handwriting")
async def detect_handwriting_endpoint(
image: UploadFile = File(...),
return_mask: bool = Form(default=True),
min_confidence: float = Form(default=0.3)
) -> DetectionResponse:
"""
Detect handwriting in an image.
Args:
image: Input image (PNG, JPG)
return_mask: Whether to return the binary mask as base64
min_confidence: Minimum confidence threshold
Returns:
DetectionResponse with detection results and optional mask
"""
logger.info(f"Handwriting detection request: {image.filename}")
# Validate file type
content_type = image.content_type or ""
if not content_type.startswith("image/"):
raise HTTPException(
status_code=400,
detail="Only image files (PNG, JPG) are supported"
)
try:
image_bytes = await image.read()
# Detect handwriting
result = detect_handwriting(image_bytes)
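# Flag handwriting only if the detector confidence clears the caller's threshold
# AND at least 0.5% of the page pixels are marked as handwriting.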
has_handwriting = (
result.confidence >= min_confidence and
result.handwriting_ratio > 0.005
)
response = DetectionResponse(
has_handwriting=has_handwriting,
confidence=result.confidence,
handwriting_ratio=result.handwriting_ratio,
detection_method=result.detection_method
)
if return_mask:
mask_bytes = mask_to_png(result.mask)
response.mask_base64 = base64.b64encode(mask_bytes).decode('utf-8')
logger.info(f"Detection complete: handwriting={has_handwriting}, "
f"confidence={result.confidence:.2f}")
return response
except Exception as e:
logger.error(f"Handwriting detection failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
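# Example request against the endpoint above (host/port are placeholders):
#   curl -X POST http://localhost:8000/api/v1/worksheet/detect-handwriting \
#        -F "image=@scan.png" -F "return_mask=true" -F "min_confidence=0.3"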
@router.post("/detect-handwriting/mask")
async def get_handwriting_mask(
image: UploadFile = File(...)
) -> StreamingResponse:
"""
Get handwriting detection mask as PNG image.
Returns binary mask where white (255) = handwriting.
"""
content_type = image.content_type or ""
if not content_type.startswith("image/"):
raise HTTPException(
status_code=400,
detail="Only image files are supported"
)
try:
image_bytes = await image.read()
result = detect_handwriting(image_bytes)
mask_bytes = mask_to_png(result.mask)
return StreamingResponse(
io.BytesIO(mask_bytes),
media_type="image/png",
headers={
"Content-Disposition": "attachment; filename=handwriting_mask.png"
}
)
except Exception as e:
logger.error(f"Mask generation failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/remove-handwriting")
async def remove_handwriting_endpoint(
image: UploadFile = File(...),
mask: Optional[UploadFile] = File(default=None),
method: str = Form(default="auto"),
return_base64: bool = Form(default=False)
):
"""
Remove handwriting from an image.
Args:
image: Input image with handwriting
mask: Optional pre-computed mask (if not provided, auto-detected)
method: Inpainting method (auto, opencv_telea, opencv_ns, lama)
return_base64: If True, return image as base64, else as file
Returns:
Cleaned image (as PNG file or base64 in JSON)
"""
logger.info(f"Remove handwriting request: {image.filename}, method={method}")
content_type = image.content_type or ""
if not content_type.startswith("image/"):
raise HTTPException(
status_code=400,
detail="Only image files are supported"
)
try:
image_bytes = await image.read()
# Get mask if provided
mask_array = None
if mask is not None:
mask_bytes = await mask.read()
from PIL import Image
import numpy as np
mask_img = Image.open(io.BytesIO(mask_bytes))
mask_array = np.array(mask_img)
# Select inpainting method
inpainting_method = InpaintingMethod.AUTO
if method == "opencv_telea":
inpainting_method = InpaintingMethod.OPENCV_TELEA
elif method == "opencv_ns":
inpainting_method = InpaintingMethod.OPENCV_NS
elif method == "lama":
inpainting_method = InpaintingMethod.LAMA
# Remove handwriting
cleaned_bytes, metadata = remove_handwriting(
image_bytes,
mask=mask_array,
method=inpainting_method
)
if return_base64:
return JSONResponse({
"success": True,
"image_base64": base64.b64encode(cleaned_bytes).decode('utf-8'),
"metadata": metadata
})
else:
return StreamingResponse(
io.BytesIO(cleaned_bytes),
media_type="image/png",
headers={
"Content-Disposition": "attachment; filename=cleaned.png",
"X-Method-Used": metadata.get("method_used", "unknown"),
"X-Processing-Time-Ms": str(metadata.get("processing_time_ms", 0))
}
)
except Exception as e:
logger.error(f"Handwriting removal failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/reconstruct")
async def reconstruct_layout_endpoint(
image: UploadFile = File(...),
clean_handwriting: bool = Form(default=True),
detect_tables: bool = Form(default=True)
) -> ReconstructionResponse:
"""
Reconstruct worksheet layout and generate Fabric.js JSON.
Args:
image: Input image (can contain handwriting)
clean_handwriting: Whether to remove handwriting first
detect_tables: Whether to detect table structures
Returns:
ReconstructionResponse with Fabric.js JSON
"""
logger.info(f"Layout reconstruction request: {image.filename}")
content_type = image.content_type or ""
if not content_type.startswith("image/"):
raise HTTPException(
status_code=400,
detail="Only image files are supported"
)
try:
image_bytes = await image.read()
# Run reconstruction pipeline
if clean_handwriting:
cleaned_bytes, layout = reconstruct_and_clean(image_bytes)
else:
layout = reconstruct_layout(image_bytes, detect_tables=detect_tables)
return ReconstructionResponse(
success=True,
element_count=len(layout.elements),
page_width=layout.page_width,
page_height=layout.page_height,
fabric_json=layout.fabric_json,
table_count=len(layout.table_regions)
)
except Exception as e:
logger.error(f"Layout reconstruction failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/cleanup-pipeline")
async def full_cleanup_pipeline(
image: UploadFile = File(...),
remove_hw: bool = Form(default=True, alias="remove_handwriting"),
reconstruct: bool = Form(default=True),
inpainting_method: str = Form(default="auto")
) -> PipelineResponse:
"""
Full cleanup pipeline: detect, remove handwriting, reconstruct layout.
This is the recommended endpoint for processing filled worksheets.
Args:
image: Input image (scan/photo of filled worksheet)
remove_handwriting: Whether to remove detected handwriting
reconstruct: Whether to reconstruct layout as Fabric.js JSON
inpainting_method: Method for inpainting (auto, opencv_telea, opencv_ns, lama)
Returns:
PipelineResponse with cleaned image and Fabric.js JSON
"""
logger.info(f"Full cleanup pipeline: {image.filename}")
content_type = image.content_type or ""
if not content_type.startswith("image/"):
raise HTTPException(
status_code=400,
detail="Only image files are supported"
)
try:
image_bytes = await image.read()
metadata = {}
# Step 1: Detect handwriting
detection = detect_handwriting(image_bytes)
handwriting_detected = (
detection.confidence >= 0.3 and
detection.handwriting_ratio > 0.005
)
metadata["detection"] = {
"confidence": detection.confidence,
"handwriting_ratio": detection.handwriting_ratio,
"method": detection.detection_method
}
# Step 2: Remove handwriting if requested and detected
cleaned_bytes = image_bytes
handwriting_removed = False
if remove_hw and handwriting_detected:
method = InpaintingMethod.AUTO
if inpainting_method == "opencv_telea":
method = InpaintingMethod.OPENCV_TELEA
elif inpainting_method == "opencv_ns":
method = InpaintingMethod.OPENCV_NS
elif inpainting_method == "lama":
method = InpaintingMethod.LAMA
cleaned_bytes, inpaint_metadata = remove_handwriting(
image_bytes,
mask=detection.mask,
method=method
)
handwriting_removed = inpaint_metadata.get("inpainting_performed", False)
metadata["inpainting"] = inpaint_metadata
# Step 3: Reconstruct layout if requested
fabric_json = None
layout_reconstructed = False
if reconstruct:
layout = reconstruct_layout(cleaned_bytes)
fabric_json = layout.fabric_json
layout_reconstructed = len(layout.elements) > 0
metadata["layout"] = {
"element_count": len(layout.elements),
"table_count": len(layout.table_regions),
"page_width": layout.page_width,
"page_height": layout.page_height
}
# Encode cleaned image as base64
cleaned_base64 = base64.b64encode(cleaned_bytes).decode('utf-8')
logger.info(f"Pipeline complete: detected={handwriting_detected}, "
f"removed={handwriting_removed}, layout={layout_reconstructed}")
return PipelineResponse(
success=True,
handwriting_detected=handwriting_detected,
handwriting_removed=handwriting_removed,
layout_reconstructed=layout_reconstructed,
cleaned_image_base64=cleaned_base64,
fabric_json=fabric_json,
metadata=metadata
)
except Exception as e:
logger.error(f"Cleanup pipeline failed: {e}")
import traceback
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
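# Example client call for the pipeline endpoint above (hypothetical host/port and
# file name):
#   import httpx
#   with open("filled_worksheet.png", "rb") as f:
#       r = httpx.post(
#           "http://localhost:8000/api/v1/worksheet/cleanup-pipeline",
#           files={"image": ("filled_worksheet.png", f, "image/png")},
#           data={"remove_handwriting": "true", "reconstruct": "true", "inpainting_method": "auto"},
#           timeout=120.0,
#       )
#   payload = r.json()  # cleaned_image_base64, fabric_json, metadata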
@router.post("/preview-cleanup")
async def preview_cleanup(
image: UploadFile = File(...)
) -> JSONResponse:
"""
Quick preview of cleanup results without full processing.
Returns detection results and estimated processing time.
"""
content_type = image.content_type or ""
if not content_type.startswith("image/"):
raise HTTPException(
status_code=400,
detail="Only image files are supported"
)
try:
image_bytes = await image.read()
# Quick detection only
result = detect_handwriting_regions(image_bytes)
# Estimate processing time based on image size
from PIL import Image
img = Image.open(io.BytesIO(image_bytes))
pixel_count = img.width * img.height
# Rough estimates
est_detection_ms = 100 + (pixel_count / 1000000) * 200
est_inpainting_ms = 500 + (pixel_count / 1000000) * 1000
est_reconstruction_ms = 200 + (pixel_count / 1000000) * 300
return JSONResponse({
"has_handwriting": result["has_handwriting"],
"confidence": result["confidence"],
"handwriting_ratio": result["handwriting_ratio"],
"image_width": img.width,
"image_height": img.height,
"estimated_times_ms": {
"detection": est_detection_ms,
"inpainting": est_inpainting_ms if result["has_handwriting"] else 0,
"reconstruction": est_reconstruction_ms,
"total": est_detection_ms + (est_inpainting_ms if result["has_handwriting"] else 0) + est_reconstruction_ms
},
"capabilities": {
"lama_available": check_lama_available()
}
})
except Exception as e:
logger.error(f"Preview failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Backward-compat shim -- module moved to worksheet/cleanup_api.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("worksheet.cleanup_api")
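# How the shim works: importing the old top-level module executes this file, which
# immediately replaces its own sys.modules entry with the relocated package module,
# so existing `import` / `from ... import` statements keep resolving to
# worksheet/cleanup_api.py without any caller changes.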
+4 -485
View File
@@ -1,485 +1,4 @@
"""
Worksheet Editor AI — AI image generation and AI worksheet modification.
"""
import io
import json
import base64
import logging
import re
import time
import random
from typing import List, Dict
import httpx
from worksheet_editor_models import (
AIImageRequest,
AIImageResponse,
AIImageStyle,
AIModifyRequest,
AIModifyResponse,
OLLAMA_URL,
STYLE_PROMPTS,
)
logger = logging.getLogger(__name__)
# =============================================
# AI IMAGE GENERATION
# =============================================
async def generate_ai_image_logic(request: AIImageRequest) -> AIImageResponse:
"""
Generate an AI image using Ollama with a text-to-image model.
Falls back to a placeholder if Ollama is unreachable or no Stable Diffusion
model is installed (the direct image-generation call is not implemented yet).
"""
from fastapi import HTTPException
try:
# Build enhanced prompt with style
style_modifier = STYLE_PROMPTS.get(request.style, "")
enhanced_prompt = f"{request.prompt}, {style_modifier}"
logger.info(f"Generating AI image: {enhanced_prompt[:100]}...")
# Check if Ollama is available
async with httpx.AsyncClient(timeout=10.0) as check_client:
try:
health_response = await check_client.get(f"{OLLAMA_URL}/api/tags")
if health_response.status_code != 200:
raise HTTPException(status_code=503, detail="Ollama service not available")
except httpx.ConnectError:
logger.warning("Ollama not reachable, returning placeholder")
return _generate_placeholder_image(request, enhanced_prompt)
try:
async with httpx.AsyncClient(timeout=300.0) as client:
tags_response = await client.get(f"{OLLAMA_URL}/api/tags")
available_models = [m.get("name", "") for m in tags_response.json().get("models", [])]
sd_model = None
for model in available_models:
if "stable" in model.lower() or "sd" in model.lower() or "diffusion" in model.lower():
sd_model = model
break
if not sd_model:
logger.warning("No Stable Diffusion model found in Ollama")
return _generate_placeholder_image(request, enhanced_prompt)
logger.info(f"SD model found: {sd_model}, but image generation API not implemented")
return _generate_placeholder_image(request, enhanced_prompt)
except Exception as e:
logger.error(f"Image generation failed: {e}")
return _generate_placeholder_image(request, enhanced_prompt)
except HTTPException:
raise
except Exception as e:
logger.error(f"AI image generation error: {e}")
raise HTTPException(status_code=500, detail=str(e))
def _generate_placeholder_image(request: AIImageRequest, prompt: str) -> AIImageResponse:
"""
Generate a placeholder image when AI generation is not available.
Creates a simple PIL-drawn PNG placeholder with the prompt text.
"""
from PIL import Image, ImageDraw, ImageFont
width, height = request.width, request.height
style_colors = {
AIImageStyle.REALISTIC: ("#2563eb", "#dbeafe"),
AIImageStyle.CARTOON: ("#f97316", "#ffedd5"),
AIImageStyle.SKETCH: ("#6b7280", "#f3f4f6"),
AIImageStyle.CLIPART: ("#8b5cf6", "#ede9fe"),
AIImageStyle.EDUCATIONAL: ("#059669", "#d1fae5"),
}
fg_color, bg_color = style_colors.get(request.style, ("#6366f1", "#e0e7ff"))
img = Image.new('RGB', (width, height), bg_color)
draw = ImageDraw.Draw(img)
draw.rectangle([5, 5, width-6, height-6], outline=fg_color, width=3)
cx, cy = width // 2, height // 2 - 30
draw.ellipse([cx-40, cy-40, cx+40, cy+40], outline=fg_color, width=3)
draw.line([cx-20, cy-10, cx+20, cy-10], fill=fg_color, width=3)
draw.line([cx, cy-10, cx, cy+20], fill=fg_color, width=3)
try:
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
except Exception:
font = ImageFont.load_default()
max_chars = 40
lines = []
words = prompt[:200].split()
current_line = ""
for word in words:
if len(current_line) + len(word) + 1 <= max_chars:
current_line += (" " + word if current_line else word)
else:
if current_line:
lines.append(current_line)
current_line = word
if current_line:
lines.append(current_line)
text_y = cy + 60
for line in lines[:4]:
bbox = draw.textbbox((0, 0), line, font=font)
text_width = bbox[2] - bbox[0]
draw.text((cx - text_width // 2, text_y), line, fill=fg_color, font=font)
text_y += 20
badge_text = "KI-Bild (Platzhalter)"
try:
badge_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 10)
except Exception:
badge_font = font
draw.rectangle([10, height-30, 150, height-10], fill=fg_color)
draw.text((15, height-27), badge_text, fill="white", font=badge_font)
buffer = io.BytesIO()
img.save(buffer, format='PNG')
buffer.seek(0)
image_base64 = f"data:image/png;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
return AIImageResponse(
image_base64=image_base64,
prompt_used=prompt,
error="AI image generation not available. Using placeholder."
)
# =============================================
# AI WORKSHEET MODIFICATION
# =============================================
async def modify_worksheet_with_ai_logic(request: AIModifyRequest) -> AIModifyResponse:
"""
Modify a worksheet using AI based on natural language prompt.
"""
try:
logger.info(f"AI modify request: {request.prompt[:100]}...")
try:
canvas_data = json.loads(request.canvas_json)
except json.JSONDecodeError:
return AIModifyResponse(
message="Fehler beim Parsen des Canvas",
error="Invalid canvas JSON"
)
system_prompt = """Du bist ein Assistent fuer die Bearbeitung von Arbeitsblaettern.
Du erhaeltst den aktuellen Zustand eines Canvas im JSON-Format und eine Anweisung des Nutzers.
Deine Aufgabe ist es, die gewuenschten Aenderungen am Canvas vorzunehmen.
Der Canvas verwendet Fabric.js. Hier sind die wichtigsten Objekttypen:
- i-text: Interaktiver Text mit fontFamily, fontSize, fill, left, top
- rect: Rechteck mit left, top, width, height, fill, stroke, strokeWidth
- circle: Kreis mit left, top, radius, fill, stroke, strokeWidth
- line: Linie mit x1, y1, x2, y2, stroke, strokeWidth
Das Canvas ist 794x1123 Pixel (A4 bei 96 DPI).
Antworte NUR mit einem JSON-Objekt in diesem Format:
{
"action": "modify" oder "add" oder "delete" oder "info",
"objects": [...], // Neue/modifizierte Objekte (bei modify/add)
"message": "Kurze Beschreibung der Aenderung"
}
Wenn du Objekte hinzufuegst, generiere eindeutige IDs im Format "obj_<timestamp>_<random>".
"""
user_prompt = f"""Aktueller Canvas-Zustand:
```json
{json.dumps(canvas_data, indent=2)[:5000]}
```
Nutzer-Anweisung: {request.prompt}
Fuehre die Aenderung durch und antworte mit dem JSON-Objekt."""
try:
async with httpx.AsyncClient(timeout=120.0) as client:
response = await client.post(
f"{OLLAMA_URL}/api/generate",
json={
"model": request.model,
"prompt": user_prompt,
"system": system_prompt,
"stream": False,
"options": {
"temperature": 0.3,
"num_predict": 4096
}
}
)
if response.status_code != 200:
logger.warning(f"Ollama error: {response.status_code}, trying local fallback")
return _handle_simple_modification(request.prompt, canvas_data)
ai_response = response.json().get("response", "")
except httpx.ConnectError:
logger.warning("Ollama not reachable")
return _handle_simple_modification(request.prompt, canvas_data)
except httpx.TimeoutException:
logger.warning("Ollama timeout, trying local fallback")
return _handle_simple_modification(request.prompt, canvas_data)
try:
json_start = ai_response.find('{')
json_end = ai_response.rfind('}') + 1
if json_start == -1 or json_end <= json_start:
logger.warning(f"No JSON found in AI response: {ai_response[:200]}")
return AIModifyResponse(
message="KI konnte die Anfrage nicht verarbeiten",
error="No JSON in response"
)
ai_json = json.loads(ai_response[json_start:json_end])
action = ai_json.get("action", "info")
message = ai_json.get("message", "Aenderungen angewendet")
new_objects = ai_json.get("objects", [])
if action == "info":
return AIModifyResponse(message=message)
if action == "add" and new_objects:
existing_objects = canvas_data.get("objects", [])
existing_objects.extend(new_objects)
canvas_data["objects"] = existing_objects
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message=message
)
if action == "modify" and new_objects:
existing_objects = canvas_data.get("objects", [])
new_ids = {obj.get("id") for obj in new_objects if obj.get("id")}
kept_objects = [obj for obj in existing_objects if obj.get("id") not in new_ids]
kept_objects.extend(new_objects)
canvas_data["objects"] = kept_objects
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message=message
)
if action == "delete":
delete_ids = ai_json.get("delete_ids", [])
if delete_ids:
existing_objects = canvas_data.get("objects", [])
canvas_data["objects"] = [obj for obj in existing_objects if obj.get("id") not in delete_ids]
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message=message
)
return AIModifyResponse(message=message)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse AI JSON: {e}")
return AIModifyResponse(
message="Fehler beim Verarbeiten der KI-Antwort",
error=str(e)
)
except Exception as e:
logger.error(f"AI modify error: {e}")
return AIModifyResponse(
message="Ein unerwarteter Fehler ist aufgetreten",
error=str(e)
)
def _handle_simple_modification(prompt: str, canvas_data: dict) -> AIModifyResponse:
"""
Handle simple modifications locally when Ollama is not available.
Supports basic commands like adding headings, lines, etc.
"""
prompt_lower = prompt.lower()
objects = canvas_data.get("objects", [])
def generate_id():
return f"obj_{int(time.time()*1000)}_{random.randint(1000, 9999)}"
# Add heading
if "ueberschrift" in prompt_lower or "titel" in prompt_lower or "heading" in prompt_lower:
text_match = re.search(r'"([^"]+)"', prompt)
text = text_match.group(1) if text_match else "Ueberschrift"
new_text = {
"type": "i-text", "id": generate_id(), "text": text,
"left": 397, "top": 50, "originX": "center",
"fontFamily": "Arial", "fontSize": 28, "fontWeight": "bold", "fill": "#000000"
}
objects.append(new_text)
canvas_data["objects"] = objects
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message=f"Ueberschrift '{text}' hinzugefuegt"
)
# Add lines for writing
if "linie" in prompt_lower or "line" in prompt_lower or "schreib" in prompt_lower:
num_match = re.search(r'(\d+)', prompt)
num_lines = int(num_match.group(1)) if num_match else 5
num_lines = min(num_lines, 20)
start_y = 150
line_spacing = 40
for i in range(num_lines):
new_line = {
"type": "line", "id": generate_id(),
"x1": 60, "y1": start_y + i * line_spacing,
"x2": 734, "y2": start_y + i * line_spacing,
"stroke": "#cccccc", "strokeWidth": 1
}
objects.append(new_line)
canvas_data["objects"] = objects
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message=f"{num_lines} Schreiblinien hinzugefuegt"
)
# Make text bigger
if "groesser" in prompt_lower or "bigger" in prompt_lower or "larger" in prompt_lower:
modified = 0
for obj in objects:
if obj.get("type") in ["i-text", "text", "textbox"]:
current_size = obj.get("fontSize", 16)
obj["fontSize"] = int(current_size * 1.25)
modified += 1
canvas_data["objects"] = objects
if modified > 0:
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message=f"{modified} Texte vergroessert"
)
# Center elements
if "zentrier" in prompt_lower or "center" in prompt_lower or "mitte" in prompt_lower:
center_x = 397
for obj in objects:
if not obj.get("isGrid"):
obj["left"] = center_x
obj["originX"] = "center"
canvas_data["objects"] = objects
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message="Elemente zentriert"
)
# Add numbering
if "nummer" in prompt_lower or "nummerier" in prompt_lower or "1-10" in prompt_lower:
range_match = re.search(r'(\d+)\s*[-bis]+\s*(\d+)', prompt)
if range_match:
start, end = int(range_match.group(1)), int(range_match.group(2))
else:
start, end = 1, 10
y = 100
for i in range(start, min(end + 1, start + 20)):
new_text = {
"type": "i-text", "id": generate_id(), "text": f"{i}.",
"left": 40, "top": y, "fontFamily": "Arial", "fontSize": 14, "fill": "#000000"
}
objects.append(new_text)
y += 35
canvas_data["objects"] = objects
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message=f"Nummerierung {start}-{end} hinzugefuegt"
)
# Add rectangle/box
if "rechteck" in prompt_lower or "box" in prompt_lower or "kasten" in prompt_lower:
new_rect = {
"type": "rect", "id": generate_id(),
"left": 100, "top": 200, "width": 200, "height": 100,
"fill": "transparent", "stroke": "#000000", "strokeWidth": 2
}
objects.append(new_rect)
canvas_data["objects"] = objects
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message="Rechteck hinzugefuegt"
)
# Add grid/raster
if "raster" in prompt_lower or "grid" in prompt_lower or "tabelle" in prompt_lower:
dim_match = re.search(r'(\d+)\s*[x/\u00d7\*mal by]\s*(\d+)', prompt_lower)
if dim_match:
cols = int(dim_match.group(1))
rows = int(dim_match.group(2))
else:
nums = re.findall(r'(\d+)', prompt)
if len(nums) >= 2:
cols, rows = int(nums[0]), int(nums[1])
else:
cols, rows = 3, 4
cols = min(max(1, cols), 10)
rows = min(max(1, rows), 15)
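# Lay the grid out on the A4 canvas (794x1123 px): 60 px outer margin, roughly
# 80 px reserved at the top for the header area, and cell height capped at 80 px
# so that tall grids still fit on a single page.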
canvas_width = 794
canvas_height = 1123
margin = 60
available_width = canvas_width - 2 * margin
available_height = canvas_height - 2 * margin - 80
cell_width = available_width / cols
cell_height = min(available_height / rows, 80)
start_x = margin
start_y = 120
grid_objects = []
for r in range(rows + 1):
y = start_y + r * cell_height
grid_objects.append({
"type": "line", "id": generate_id(),
"x1": start_x, "y1": y,
"x2": start_x + cols * cell_width, "y2": y,
"stroke": "#666666", "strokeWidth": 1, "isGrid": True
})
for c in range(cols + 1):
x = start_x + c * cell_width
grid_objects.append({
"type": "line", "id": generate_id(),
"x1": x, "y1": start_y,
"x2": x, "y2": start_y + rows * cell_height,
"stroke": "#666666", "strokeWidth": 1, "isGrid": True
})
objects.extend(grid_objects)
canvas_data["objects"] = objects
return AIModifyResponse(
modified_canvas_json=json.dumps(canvas_data),
message=f"{cols}x{rows} Raster hinzugefuegt ({cols} Spalten, {rows} Zeilen)"
)
# Default: Ollama needed
return AIModifyResponse(
message="Diese Aenderung erfordert den KI-Service. Bitte stellen Sie sicher, dass Ollama laeuft.",
error="Complex modification requires Ollama"
)
# Backward-compat shim -- module moved to worksheet/editor_ai.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("worksheet.editor_ai")
+4 -388
View File
@@ -1,388 +1,4 @@
"""
Worksheet Editor API - Backend Endpoints for Visual Worksheet Editor
Provides endpoints for:
- AI Image generation via Ollama/Stable Diffusion
- Worksheet Save/Load
- PDF Export
Split modules:
- worksheet_editor_models: Enums, Pydantic models, configuration
- worksheet_editor_ai: AI image generation and AI worksheet modification
- worksheet_editor_reconstruct: Document reconstruction from vocab sessions
"""
import os
import io
import json
import logging
from datetime import datetime, timezone
import uuid
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
import httpx
# Re-export everything from sub-modules for backward compatibility
from worksheet_editor_models import ( # noqa: F401
AIImageStyle,
WorksheetStatus,
AIImageRequest,
AIImageResponse,
PageData,
PageFormat,
WorksheetSaveRequest,
WorksheetResponse,
AIModifyRequest,
AIModifyResponse,
ReconstructRequest,
ReconstructResponse,
worksheets_db,
OLLAMA_URL,
SD_MODEL,
WORKSHEET_STORAGE_DIR,
STYLE_PROMPTS,
REPORTLAB_AVAILABLE,
)
from worksheet_editor_ai import ( # noqa: F401
generate_ai_image_logic,
_generate_placeholder_image,
modify_worksheet_with_ai_logic,
_handle_simple_modification,
)
from worksheet_editor_reconstruct import ( # noqa: F401
reconstruct_document_logic,
_detect_image_regions,
)
logger = logging.getLogger(__name__)
# =============================================
# ROUTER
# =============================================
router = APIRouter(prefix="/api/v1/worksheet", tags=["Worksheet Editor"])
# =============================================
# AI IMAGE GENERATION
# =============================================
@router.post("/ai-image", response_model=AIImageResponse)
async def generate_ai_image(request: AIImageRequest):
"""
Generate an AI image using Ollama with a text-to-image model.
Supported models:
- stable-diffusion (via Ollama)
- sd3.5-medium
- llava (for image understanding, not generation)
Falls back to a placeholder if Ollama is not available.
"""
return await generate_ai_image_logic(request)
# =============================================
# WORKSHEET SAVE/LOAD
# =============================================
@router.post("/save", response_model=WorksheetResponse)
async def save_worksheet(request: WorksheetSaveRequest):
"""
Save a worksheet document.
- If id is provided, updates existing worksheet
- If id is not provided, creates new worksheet
"""
try:
now = datetime.now(timezone.utc).isoformat()
worksheet_id = request.id or f"ws_{uuid.uuid4().hex[:12]}"
worksheet = {
"id": worksheet_id,
"title": request.title,
"description": request.description,
"pages": [p.dict() for p in request.pages],
"pageFormat": (request.pageFormat or PageFormat()).dict(),
"createdAt": worksheets_db.get(worksheet_id, {}).get("createdAt", now),
"updatedAt": now
}
worksheets_db[worksheet_id] = worksheet
filepath = os.path.join(WORKSHEET_STORAGE_DIR, f"{worksheet_id}.json")
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(worksheet, f, ensure_ascii=False, indent=2)
logger.info(f"Saved worksheet: {worksheet_id}")
return WorksheetResponse(**worksheet)
except Exception as e:
logger.error(f"Failed to save worksheet: {e}")
raise HTTPException(status_code=500, detail=f"Failed to save: {str(e)}")
@router.get("/{worksheet_id}", response_model=WorksheetResponse)
async def get_worksheet(worksheet_id: str):
"""Load a worksheet document by ID."""
try:
if worksheet_id in worksheets_db:
return WorksheetResponse(**worksheets_db[worksheet_id])
filepath = os.path.join(WORKSHEET_STORAGE_DIR, f"{worksheet_id}.json")
if os.path.exists(filepath):
with open(filepath, 'r', encoding='utf-8') as f:
worksheet = json.load(f)
worksheets_db[worksheet_id] = worksheet
return WorksheetResponse(**worksheet)
raise HTTPException(status_code=404, detail="Worksheet not found")
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to load worksheet {worksheet_id}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to load: {str(e)}")
@router.get("/list/all")
async def list_worksheets():
"""List all available worksheets."""
try:
worksheets = []
for filename in os.listdir(WORKSHEET_STORAGE_DIR):
if filename.endswith('.json'):
filepath = os.path.join(WORKSHEET_STORAGE_DIR, filename)
try:
with open(filepath, 'r', encoding='utf-8') as f:
worksheet = json.load(f)
worksheets.append({
"id": worksheet.get("id"),
"title": worksheet.get("title"),
"description": worksheet.get("description"),
"pageCount": len(worksheet.get("pages", [])),
"updatedAt": worksheet.get("updatedAt"),
"createdAt": worksheet.get("createdAt")
})
except Exception as e:
logger.warning(f"Failed to load {filename}: {e}")
worksheets.sort(key=lambda x: x.get("updatedAt", ""), reverse=True)
return {"worksheets": worksheets, "total": len(worksheets)}
except Exception as e:
logger.error(f"Failed to list worksheets: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/{worksheet_id}")
async def delete_worksheet(worksheet_id: str):
"""Delete a worksheet document."""
try:
if worksheet_id in worksheets_db:
del worksheets_db[worksheet_id]
filepath = os.path.join(WORKSHEET_STORAGE_DIR, f"{worksheet_id}.json")
if os.path.exists(filepath):
os.remove(filepath)
logger.info(f"Deleted worksheet: {worksheet_id}")
return {"status": "deleted", "id": worksheet_id}
raise HTTPException(status_code=404, detail="Worksheet not found")
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to delete worksheet {worksheet_id}: {e}")
raise HTTPException(status_code=500, detail=str(e))
# =============================================
# PDF EXPORT
# =============================================
@router.post("/{worksheet_id}/export-pdf")
async def export_worksheet_pdf(worksheet_id: str):
"""
Export worksheet as PDF.
Note: This creates a basic PDF. For full canvas rendering,
the frontend should use pdf-lib with canvas.toDataURL().
"""
if not REPORTLAB_AVAILABLE:
raise HTTPException(status_code=501, detail="PDF export not available (reportlab not installed)")
try:
from reportlab.lib.pagesizes import A4
from reportlab.pdfgen import canvas
worksheet = worksheets_db.get(worksheet_id)
if not worksheet:
filepath = os.path.join(WORKSHEET_STORAGE_DIR, f"{worksheet_id}.json")
if os.path.exists(filepath):
with open(filepath, 'r', encoding='utf-8') as f:
worksheet = json.load(f)
else:
raise HTTPException(status_code=404, detail="Worksheet not found")
buffer = io.BytesIO()
c = canvas.Canvas(buffer, pagesize=A4)
page_width, page_height = A4
for page_data in worksheet.get("pages", []):
if page_data.get("index", 0) == 0:
c.setFont("Helvetica-Bold", 18)
c.drawString(50, page_height - 50, worksheet.get("title", "Arbeitsblatt"))
c.setFont("Helvetica", 10)
c.drawString(50, page_height - 70, f"Erstellt: {worksheet.get('createdAt', '')[:10]}")
canvas_json_str = page_data.get("canvasJSON", "{}")
if canvas_json_str:
try:
canvas_data = json.loads(canvas_json_str)
objects = canvas_data.get("objects", [])
for obj in objects:
obj_type = obj.get("type", "")
if obj_type in ["text", "i-text", "textbox"]:
text = obj.get("text", "")
left = obj.get("left", 50)
top = obj.get("top", 100)
font_size = obj.get("fontSize", 12)
pdf_x = left * 0.75
pdf_y = page_height - (top * 0.75)
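# 0.75 converts the editor's 96-DPI CSS pixels to 72-DPI PDF points (72/96),
# and the Y coordinate is flipped because ReportLab's origin is bottom-left.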
c.setFont("Helvetica", min(font_size, 24))
c.drawString(pdf_x, pdf_y, text[:100])
elif obj_type == "rect":
left = obj.get("left", 0) * 0.75
top = obj.get("top", 0) * 0.75
width = obj.get("width", 50) * 0.75
height = obj.get("height", 30) * 0.75
c.rect(left, page_height - top - height, width, height)
elif obj_type == "circle":
left = obj.get("left", 0) * 0.75
top = obj.get("top", 0) * 0.75
radius = obj.get("radius", 25) * 0.75
c.circle(left + radius, page_height - top - radius, radius)
except json.JSONDecodeError:
pass
c.showPage()
c.save()
buffer.seek(0)
filename = f"{worksheet.get('title', 'worksheet').replace(' ', '_')}.pdf"
return StreamingResponse(
buffer,
media_type="application/pdf",
headers={"Content-Disposition": f"attachment; filename={filename}"}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"PDF export failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
# =============================================
# AI WORKSHEET MODIFICATION
# =============================================
@router.post("/ai-modify", response_model=AIModifyResponse)
async def modify_worksheet_with_ai(request: AIModifyRequest):
"""
Modify a worksheet using AI based on natural language prompt.
Uses Ollama with qwen2.5vl:32b to understand the canvas state
and generate modifications based on the user's request.
"""
return await modify_worksheet_with_ai_logic(request)
# =============================================
# HEALTH CHECK
# =============================================
@router.get("/health/check")
async def health_check():
"""Check worksheet editor API health and dependencies."""
status = {
"status": "healthy",
"ollama": False,
"storage": os.path.exists(WORKSHEET_STORAGE_DIR),
"reportlab": REPORTLAB_AVAILABLE,
"worksheets_count": len(worksheets_db)
}
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(f"{OLLAMA_URL}/api/tags")
status["ollama"] = response.status_code == 200
except Exception:
pass
return status
# =============================================
# DOCUMENT RECONSTRUCTION FROM VOCAB SESSION
# =============================================
@router.post("/reconstruct-from-session", response_model=ReconstructResponse)
async def reconstruct_document_from_session(request: ReconstructRequest):
"""
Reconstruct a document from a vocab session into Fabric.js canvas format.
Returns canvas JSON ready to load into the worksheet editor.
"""
try:
return await reconstruct_document_logic(request)
except HTTPException:
raise
except Exception as e:
logger.error(f"Document reconstruction failed: {e}")
import traceback
logger.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@router.get("/sessions/available")
async def get_available_sessions():
"""Get list of available vocab sessions that can be reconstructed."""
try:
from vocab_worksheet_api import _sessions
available = []
for session_id, session in _sessions.items():
if session.get("pdf_data"):
available.append({
"id": session_id,
"name": session.get("name", "Unnamed"),
"description": session.get("description"),
"vocabulary_count": len(session.get("vocabulary", [])),
"page_count": session.get("pdf_page_count", 1),
"status": session.get("status", "unknown"),
"created_at": session.get("created_at", "").isoformat() if session.get("created_at") else None
})
return {"sessions": available, "total": len(available)}
except Exception as e:
logger.error(f"Failed to list sessions: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Backward-compat shim -- module moved to worksheet/editor_api.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("worksheet.editor_api")
@@ -1,133 +1,4 @@
"""
Worksheet Editor Models — Enums, Pydantic models, and configuration.
"""
import os
import logging
from typing import Optional, List, Dict
from enum import Enum
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
# =============================================
# CONFIGURATION
# =============================================
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434")
SD_MODEL = os.getenv("SD_MODEL", "stable-diffusion") # or specific SD model
WORKSHEET_STORAGE_DIR = os.getenv("WORKSHEET_STORAGE_DIR",
os.path.join(os.path.dirname(os.path.abspath(__file__)), "worksheet-storage"))
# Ensure storage directory exists
os.makedirs(WORKSHEET_STORAGE_DIR, exist_ok=True)
# =============================================
# ENUMS & MODELS
# =============================================
class AIImageStyle(str, Enum):
REALISTIC = "realistic"
CARTOON = "cartoon"
SKETCH = "sketch"
CLIPART = "clipart"
EDUCATIONAL = "educational"
class WorksheetStatus(str, Enum):
DRAFT = "draft"
PUBLISHED = "published"
ARCHIVED = "archived"
# Style prompt modifiers
STYLE_PROMPTS = {
AIImageStyle.REALISTIC: "photorealistic, high detail, professional photography",
AIImageStyle.CARTOON: "cartoon style, colorful, child-friendly, simple shapes",
AIImageStyle.SKETCH: "pencil sketch, hand-drawn, black and white, artistic",
AIImageStyle.CLIPART: "clipart style, flat design, simple, vector-like",
AIImageStyle.EDUCATIONAL: "educational illustration, clear, informative, textbook style"
}
# =============================================
# REQUEST/RESPONSE MODELS
# =============================================
class AIImageRequest(BaseModel):
prompt: str = Field(..., min_length=3, max_length=500)
style: AIImageStyle = AIImageStyle.EDUCATIONAL
width: int = Field(512, ge=256, le=1024)
height: int = Field(512, ge=256, le=1024)
class AIImageResponse(BaseModel):
image_base64: str
prompt_used: str
error: Optional[str] = None
class PageData(BaseModel):
id: str
index: int
canvasJSON: str
class PageFormat(BaseModel):
width: float = 210
height: float = 297
orientation: str = "portrait"
margins: Dict[str, float] = {"top": 15, "right": 15, "bottom": 15, "left": 15}
class WorksheetSaveRequest(BaseModel):
id: Optional[str] = None
title: str
description: Optional[str] = None
pages: List[PageData]
pageFormat: Optional[PageFormat] = None
class WorksheetResponse(BaseModel):
id: str
title: str
description: Optional[str]
pages: List[PageData]
pageFormat: PageFormat
createdAt: str
updatedAt: str
class AIModifyRequest(BaseModel):
prompt: str = Field(..., min_length=3, max_length=1000)
canvas_json: str
model: str = "qwen2.5vl:32b"
class AIModifyResponse(BaseModel):
modified_canvas_json: Optional[str] = None
message: str
error: Optional[str] = None
class ReconstructRequest(BaseModel):
session_id: str
page_number: int = 1
include_images: bool = True
regenerate_graphics: bool = False
class ReconstructResponse(BaseModel):
canvas_json: str
page_width: int
page_height: int
elements_count: int
vocabulary_matched: int
message: str
error: Optional[str] = None
# =============================================
# IN-MEMORY STORAGE (Development)
# =============================================
worksheets_db: Dict[str, Dict] = {}
# PDF Generation availability
try:
from reportlab.lib import colors # noqa: F401
from reportlab.lib.pagesizes import A4 # noqa: F401
from reportlab.lib.units import mm # noqa: F401
from reportlab.pdfgen import canvas # noqa: F401
from reportlab.lib.styles import getSampleStyleSheet # noqa: F401
REPORTLAB_AVAILABLE = True
except ImportError:
REPORTLAB_AVAILABLE = False
# Backward-compat shim -- module moved to worksheet/editor_models.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("worksheet.editor_models")
@@ -1,255 +1,4 @@
"""
Worksheet Editor Reconstruct — Document reconstruction from vocab sessions.
"""
import io
import uuid
import base64
import logging
from typing import List, Dict
import numpy as np
from worksheet_editor_models import (
ReconstructRequest,
ReconstructResponse,
)
logger = logging.getLogger(__name__)
async def reconstruct_document_logic(request: ReconstructRequest) -> ReconstructResponse:
"""
Reconstruct a document from a vocab session into Fabric.js canvas format.
This function:
1. Loads the original PDF from the vocab session
2. Runs OCR with position tracking
3. Creates Fabric.js canvas JSON with positioned elements
4. Maps extracted vocabulary to their positions
Returns ReconstructResponse ready to send to the client.
"""
from fastapi import HTTPException
from vocab_worksheet_api import _sessions, convert_pdf_page_to_image
# Check if session exists
if request.session_id not in _sessions:
raise HTTPException(status_code=404, detail=f"Session {request.session_id} not found")
session = _sessions[request.session_id]
if not session.get("pdf_data"):
raise HTTPException(status_code=400, detail="Session has no PDF data")
pdf_data = session["pdf_data"]
page_count = session.get("pdf_page_count", 1)
if request.page_number < 1 or request.page_number > page_count:
raise HTTPException(
status_code=400,
detail=f"Page {request.page_number} not found. PDF has {page_count} pages."
)
vocabulary = session.get("vocabulary", [])
page_vocab = [v for v in vocabulary if v.get("source_page") == request.page_number]
logger.info(f"Reconstructing page {request.page_number} from session {request.session_id}")
logger.info(f"Found {len(page_vocab)} vocabulary items for this page")
image_bytes = await convert_pdf_page_to_image(pdf_data, request.page_number)
if not image_bytes:
raise HTTPException(status_code=500, detail="Failed to convert PDF page to image")
from PIL import Image
img = Image.open(io.BytesIO(image_bytes))
img_width, img_height = img.size
from hybrid_vocab_extractor import run_paddle_ocr
ocr_regions, raw_text = run_paddle_ocr(image_bytes)
logger.info(f"OCR found {len(ocr_regions)} text regions")
A4_WIDTH = 794
A4_HEIGHT = 1123
scale_x = A4_WIDTH / img_width
scale_y = A4_HEIGHT / img_height
fabric_objects = []
# 1. Add white background
fabric_objects.append({
"type": "rect", "left": 0, "top": 0,
"width": A4_WIDTH, "height": A4_HEIGHT,
"fill": "#ffffff", "selectable": False,
"evented": False, "isBackground": True
})
# 2. Group OCR regions by Y-coordinate to detect rows
sorted_regions = sorted(ocr_regions, key=lambda r: (r.y1, r.x1))
# 3. Detect headers (larger text at top)
headers = []
for region in sorted_regions:
height = region.y2 - region.y1
if region.y1 < img_height * 0.15 and height > 30:
headers.append(region)
# 4. Create text objects for each region
vocab_matched = 0
for region in sorted_regions:
left = int(region.x1 * scale_x)
top = int(region.y1 * scale_y)
is_header = region in headers
region_height = region.y2 - region.y1
base_font_size = max(10, min(32, int(region_height * scale_y * 0.8)))
if is_header:
base_font_size = max(base_font_size, 24)
is_vocab = False
vocab_match = None
for v in page_vocab:
if v.get("english", "").lower() in region.text.lower() or \
v.get("german", "").lower() in region.text.lower():
is_vocab = True
vocab_match = v
vocab_matched += 1
break
text_obj = {
"type": "i-text",
"id": f"text_{uuid.uuid4().hex[:8]}",
"left": left, "top": top,
"text": region.text,
"fontFamily": "Arial",
"fontSize": base_font_size,
"fontWeight": "bold" if is_header else "normal",
"fill": "#000000",
"originX": "left", "originY": "top",
}
if is_vocab and vocab_match:
text_obj["isVocabulary"] = True
text_obj["vocabularyId"] = vocab_match.get("id")
text_obj["english"] = vocab_match.get("english")
text_obj["german"] = vocab_match.get("german")
fabric_objects.append(text_obj)
# 5. If include_images, detect and extract image regions
if request.include_images:
image_regions = await _detect_image_regions(image_bytes, ocr_regions, img_width, img_height)
for i, img_region in enumerate(image_regions):
img_x1 = int(img_region["x1"])
img_y1 = int(img_region["y1"])
img_x2 = int(img_region["x2"])
img_y2 = int(img_region["y2"])
cropped = img.crop((img_x1, img_y1, img_x2, img_y2))
buffer = io.BytesIO()
cropped.save(buffer, format='PNG')
buffer.seek(0)
img_base64 = f"data:image/png;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
fabric_objects.append({
"type": "image",
"id": f"img_{uuid.uuid4().hex[:8]}",
"left": int(img_x1 * scale_x),
"top": int(img_y1 * scale_y),
"width": int((img_x2 - img_x1) * scale_x),
"height": int((img_y2 - img_y1) * scale_y),
"src": img_base64,
"scaleX": 1, "scaleY": 1,
})
import json
canvas_data = {
"version": "6.0.0",
"objects": fabric_objects,
"background": "#ffffff"
}
return ReconstructResponse(
canvas_json=json.dumps(canvas_data),
page_width=A4_WIDTH,
page_height=A4_HEIGHT,
elements_count=len(fabric_objects),
vocabulary_matched=vocab_matched,
message=f"Reconstructed page {request.page_number} with {len(fabric_objects)} elements, "
f"{vocab_matched} vocabulary items matched"
)
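# Illustrative sketch (not part of this commit): minimal wiring of
# reconstruct_document_logic into a FastAPI route. The router, path, and
# handler name below are hypothetical; the request/response models are the
# real ones imported above.
from fastapi import APIRouter
_reconstruct_router_sketch = APIRouter()
@_reconstruct_router_sketch.post("/api/v1/worksheet-editor/reconstruct", response_model=ReconstructResponse)
async def reconstruct_endpoint_sketch(request: ReconstructRequest) -> ReconstructResponse:
    # Delegates entirely to the logic function above; errors surface as HTTPException.
    return await reconstruct_document_logic(request)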
async def _detect_image_regions(
image_bytes: bytes,
ocr_regions: list,
img_width: int,
img_height: int
) -> List[Dict]:
"""
Detect image/graphic regions in the document.
Uses a simple approach:
1. Find large gaps between text regions (potential image areas)
2. Use edge detection to find bounded regions
3. Filter out text areas
"""
from PIL import Image
import cv2
try:
img = Image.open(io.BytesIO(image_bytes))
img_array = np.array(img.convert('L'))
text_mask = np.ones_like(img_array, dtype=bool)
for region in ocr_regions:
x1 = max(0, region.x1 - 5)
y1 = max(0, region.y1 - 5)
x2 = min(img_width, region.x2 + 5)
y2 = min(img_height, region.y2 + 5)
text_mask[y1:y2, x1:x2] = False
image_regions = []
edges = cv2.Canny(img_array, 50, 150)
edges[~text_mask] = 0
contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
x, y, w, h = cv2.boundingRect(contour)
if w > 50 and h > 50:
if w < img_width * 0.9 and h < img_height * 0.9:
region_content = img_array[y:y+h, x:x+w]
variance = np.var(region_content)
if variance > 500:
image_regions.append({
"x1": x, "y1": y,
"x2": x + w, "y2": y + h
})
filtered_regions = []
for region in sorted(image_regions, key=lambda r: (r["x2"]-r["x1"])*(r["y2"]-r["y1"]), reverse=True):
overlaps = False
for existing in filtered_regions:
if not (region["x2"] < existing["x1"] or region["x1"] > existing["x2"] or
region["y2"] < existing["y1"] or region["y1"] > existing["y2"]):
overlaps = True
break
if not overlaps:
filtered_regions.append(region)
logger.info(f"Detected {len(filtered_regions)} image regions")
return filtered_regions[:10]
except Exception as e:
logger.warning(f"Image region detection failed: {e}")
return []
# Backward-compat shim -- module moved to worksheet/editor_reconstruct.py
import importlib as _importlib
import sys as _sys
_sys.modules[__name__] = _importlib.import_module("worksheet.editor_reconstruct")
@@ -0,0 +1,6 @@
"""
zeugnis package — certificate crawler, models, storage.
Backward-compatible re-exports: consumers can still use
``from zeugnis_api import ...`` etc. via the shim files in backend/.
"""
+19
View File
@@ -0,0 +1,19 @@
"""
Zeugnis Rights-Aware Crawler — barrel re-export.
All implementation split into:
zeugnis_api_sources — sources, seed URLs, initialization
zeugnis_api_docs — documents, crawler, statistics, audit
FastAPI router for managing zeugnis sources, documents, and crawler operations.
"""
from fastapi import APIRouter
from .api_sources import router as _sources_router # noqa: F401
from .api_docs import router as _docs_router # noqa: F401
# Composite router (used by main.py)
router = APIRouter()
router.include_router(_sources_router)
router.include_router(_docs_router)
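# Illustrative sketch (not part of this commit): how the composite router is
# typically mounted in main.py. The `app` object is an assumption; since the
# sub-routers already carry the /api/v1/admin/zeugnis prefix, no extra prefix
# is passed here:
#
#     from fastapi import FastAPI
#     import zeugnis_api  # backend/ shim -> this barrel
#     app = FastAPI()
#     app.include_router(zeugnis_api.router)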
+321
View File
@@ -0,0 +1,321 @@
"""
Zeugnis API Docs — documents, crawler control, statistics, audit endpoints.
Extracted from zeugnis_api.py for modularity.
"""
from datetime import datetime, timedelta
from typing import Optional, List
from fastapi import APIRouter, HTTPException, BackgroundTasks, Query
from .models import (
CrawlRequest, EventType,
BUNDESLAENDER,
generate_id, get_training_allowed, get_license_for_bundesland,
)
from .crawler import (
start_crawler, stop_crawler, get_crawler_status,
)
from metrics_db import (
get_zeugnis_documents, get_zeugnis_stats,
log_zeugnis_event, get_pool,
)
router = APIRouter(prefix="/api/v1/admin/zeugnis", tags=["Zeugnis Crawler"])
# =============================================================================
# Documents Endpoints
# =============================================================================
@router.get("/documents", response_model=List[dict])
async def list_documents(
bundesland: Optional[str] = None,
limit: int = Query(100, le=500),
offset: int = 0,
):
"""Get all zeugnis documents with optional filtering."""
documents = await get_zeugnis_documents(bundesland=bundesland, limit=limit, offset=offset)
return documents
@router.get("/documents/{document_id}", response_model=dict)
async def get_document(document_id: str):
"""Get details for a specific document."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
try:
async with pool.acquire() as conn:
doc = await conn.fetchrow(
"""
SELECT d.*, s.bundesland, s.name as source_name
FROM zeugnis_documents d
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
JOIN zeugnis_sources s ON u.source_id = s.id
WHERE d.id = $1
""",
document_id
)
if not doc:
raise HTTPException(status_code=404, detail="Document not found")
# Log view event
await log_zeugnis_event(document_id, EventType.VIEWED.value)
return dict(doc)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/documents/{document_id}/versions", response_model=List[dict])
async def get_document_versions(document_id: str):
"""Get version history for a document."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT * FROM zeugnis_document_versions
WHERE document_id = $1
ORDER BY version DESC
""",
document_id
)
return [dict(r) for r in rows]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Crawler Control Endpoints
# =============================================================================
@router.get("/crawler/status", response_model=dict)
async def crawler_status():
"""Get current crawler status."""
return get_crawler_status()
@router.post("/crawler/start", response_model=dict)
async def start_crawl(request: CrawlRequest, background_tasks: BackgroundTasks):
"""Start the crawler."""
success = await start_crawler(
bundesland=request.bundesland,
source_id=request.source_id,
)
if not success:
raise HTTPException(status_code=409, detail="Crawler already running")
return {"success": True, "message": "Crawler started"}
@router.post("/crawler/stop", response_model=dict)
async def stop_crawl():
"""Stop the crawler."""
success = await stop_crawler()
if not success:
raise HTTPException(status_code=409, detail="Crawler not running")
return {"success": True, "message": "Crawler stopped"}
@router.get("/crawler/queue", response_model=List[dict])
async def get_queue():
"""Get the crawler queue."""
pool = await get_pool()
if not pool:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT q.*, s.bundesland, s.name as source_name
FROM zeugnis_crawler_queue q
JOIN zeugnis_sources s ON q.source_id = s.id
ORDER BY q.priority DESC, q.created_at
"""
)
return [dict(r) for r in rows]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/crawler/queue", response_model=dict)
async def add_to_queue(request: CrawlRequest):
"""Add a source to the crawler queue."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
queue_id = generate_id()
try:
async with pool.acquire() as conn:
# Get source ID if bundesland provided
source_id = request.source_id
if not source_id and request.bundesland:
source = await conn.fetchrow(
"SELECT id FROM zeugnis_sources WHERE bundesland = $1",
request.bundesland
)
if source:
source_id = source["id"]
if not source_id:
raise HTTPException(status_code=400, detail="Source not found")
await conn.execute(
"""
INSERT INTO zeugnis_crawler_queue (id, source_id, priority, status)
VALUES ($1, $2, $3, 'pending')
""",
queue_id, source_id, request.priority
)
return {"id": queue_id, "success": True}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Statistics Endpoints
# =============================================================================
@router.get("/stats", response_model=dict)
async def get_stats():
"""Get zeugnis crawler statistics."""
stats = await get_zeugnis_stats()
return stats
@router.get("/stats/bundesland", response_model=List[dict])
async def get_bundesland_stats():
"""Get statistics per Bundesland."""
pool = await get_pool()
# Build stats from BUNDESLAENDER with DB data if available
stats = []
for code, info in BUNDESLAENDER.items():
stat = {
"bundesland": code,
"name": info["name"],
"training_allowed": get_training_allowed(code),
"document_count": 0,
"indexed_count": 0,
"last_crawled": None,
}
if pool:
try:
async with pool.acquire() as conn:
row = await conn.fetchrow(
"""
SELECT
COUNT(d.id) as doc_count,
COUNT(CASE WHEN d.indexed_in_qdrant THEN 1 END) as indexed_count,
MAX(u.last_crawled) as last_crawled
FROM zeugnis_sources s
LEFT JOIN zeugnis_seed_urls u ON s.id = u.source_id
LEFT JOIN zeugnis_documents d ON u.id = d.seed_url_id
WHERE s.bundesland = $1
GROUP BY s.id
""",
code
)
if row:
stat["document_count"] = row["doc_count"] or 0
stat["indexed_count"] = row["indexed_count"] or 0
stat["last_crawled"] = row["last_crawled"].isoformat() if row["last_crawled"] else None
except Exception:
pass
stats.append(stat)
return stats
# =============================================================================
# Audit Endpoints
# =============================================================================
@router.get("/audit/events", response_model=List[dict])
async def get_audit_events(
document_id: Optional[str] = None,
event_type: Optional[str] = None,
limit: int = Query(100, le=1000),
days: int = Query(30, le=365),
):
"""Get audit events with optional filtering."""
pool = await get_pool()
if not pool:
return []
try:
since = datetime.now() - timedelta(days=days)
async with pool.acquire() as conn:
query = """
SELECT * FROM zeugnis_usage_events
WHERE created_at >= $1
"""
params = [since]
if document_id:
query += " AND document_id = $2"
params.append(document_id)
if event_type:
query += f" AND event_type = ${len(params) + 1}"
params.append(event_type)
query += f" ORDER BY created_at DESC LIMIT ${len(params) + 1}"
params.append(limit)
rows = await conn.fetch(query, *params)
return [dict(r) for r in rows]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/audit/export", response_model=dict)
async def export_audit(
days: int = Query(30, le=365),
requested_by: str = Query(..., description="User requesting the export"),
):
"""Export audit data for GDPR compliance."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
try:
since = datetime.now() - timedelta(days=days)
async with pool.acquire() as conn:
rows = await conn.fetch(
"""
SELECT * FROM zeugnis_usage_events
WHERE created_at >= $1
ORDER BY created_at DESC
""",
since
)
doc_count = await conn.fetchval(
"SELECT COUNT(DISTINCT document_id) FROM zeugnis_usage_events WHERE created_at >= $1",
since
)
return {
"export_date": datetime.now().isoformat(),
"requested_by": requested_by,
"events": [dict(r) for r in rows],
"document_count": doc_count or 0,
"date_range_start": since.isoformat(),
"date_range_end": datetime.now().isoformat(),
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -0,0 +1,232 @@
"""
Zeugnis API Sources — source and seed URL management endpoints.
Extracted from zeugnis_api.py for modularity.
"""
from typing import Optional, List
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from .models import (
ZeugnisSourceCreate, ZeugnisSourceVerify,
SeedUrlCreate,
LicenseType, DocType,
BUNDESLAENDER,
generate_id, get_training_allowed, get_bundesland_name, get_license_for_bundesland,
)
from metrics_db import (
get_zeugnis_sources, upsert_zeugnis_source, get_pool,
)
router = APIRouter(prefix="/api/v1/admin/zeugnis", tags=["Zeugnis Crawler"])
# =============================================================================
# Sources Endpoints
# =============================================================================
@router.get("/sources", response_model=List[dict])
async def list_sources():
"""Get all zeugnis sources (Bundeslaender)."""
sources = await get_zeugnis_sources()
if not sources:
# Return default sources if none exist
return [
{
"id": None,
"bundesland": code,
"name": info["name"],
"base_url": None,
"license_type": str(get_license_for_bundesland(code).value),
"training_allowed": get_training_allowed(code),
"verified_by": None,
"verified_at": None,
"created_at": None,
"updated_at": None,
}
for code, info in BUNDESLAENDER.items()
]
return sources
@router.post("/sources", response_model=dict)
async def create_source(source: ZeugnisSourceCreate):
"""Create or update a zeugnis source."""
source_id = generate_id()
success = await upsert_zeugnis_source(
id=source_id,
bundesland=source.bundesland,
name=source.name,
license_type=source.license_type.value,
training_allowed=source.training_allowed,
base_url=source.base_url,
)
if not success:
raise HTTPException(status_code=500, detail="Failed to create source")
return {"id": source_id, "success": True}
@router.put("/sources/{source_id}/verify", response_model=dict)
async def verify_source(source_id: str, verification: ZeugnisSourceVerify):
"""Verify a source's license status."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
try:
async with pool.acquire() as conn:
await conn.execute(
"""
UPDATE zeugnis_sources
SET license_type = $2,
training_allowed = $3,
verified_by = $4,
verified_at = NOW(),
updated_at = NOW()
WHERE id = $1
""",
source_id, verification.license_type.value,
verification.training_allowed, verification.verified_by
)
return {"success": True, "source_id": source_id}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/sources/{bundesland}", response_model=dict)
async def get_source_by_bundesland(bundesland: str):
"""Get source details for a specific Bundesland."""
pool = await get_pool()
if not pool:
# Return default info
if bundesland not in BUNDESLAENDER:
raise HTTPException(status_code=404, detail=f"Bundesland not found: {bundesland}")
return {
"bundesland": bundesland,
"name": get_bundesland_name(bundesland),
"training_allowed": get_training_allowed(bundesland),
"license_type": get_license_for_bundesland(bundesland).value,
"document_count": 0,
}
try:
async with pool.acquire() as conn:
source = await conn.fetchrow(
"SELECT * FROM zeugnis_sources WHERE bundesland = $1",
bundesland
)
if source:
doc_count = await conn.fetchval(
"""
SELECT COUNT(*) FROM zeugnis_documents d
JOIN zeugnis_seed_urls u ON d.seed_url_id = u.id
WHERE u.source_id = $1
""",
source["id"]
)
return {**dict(source), "document_count": doc_count or 0}
# Return default
return {
"bundesland": bundesland,
"name": get_bundesland_name(bundesland),
"training_allowed": get_training_allowed(bundesland),
"license_type": get_license_for_bundesland(bundesland).value,
"document_count": 0,
}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Seed URLs Endpoints
# =============================================================================
@router.get("/sources/{source_id}/urls", response_model=List[dict])
async def list_seed_urls(source_id: str):
"""Get all seed URLs for a source."""
pool = await get_pool()
if not pool:
return []
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"SELECT * FROM zeugnis_seed_urls WHERE source_id = $1 ORDER BY created_at",
source_id
)
return [dict(r) for r in rows]
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/sources/{source_id}/urls", response_model=dict)
async def add_seed_url(source_id: str, seed_url: SeedUrlCreate):
"""Add a new seed URL to a source."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
url_id = generate_id()
try:
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO zeugnis_seed_urls (id, source_id, url, doc_type, status)
VALUES ($1, $2, $3, $4, 'pending')
""",
url_id, source_id, seed_url.url, seed_url.doc_type.value
)
return {"id": url_id, "success": True}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/urls/{url_id}", response_model=dict)
async def delete_seed_url(url_id: str):
"""Delete a seed URL."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
try:
async with pool.acquire() as conn:
await conn.execute(
"DELETE FROM zeugnis_seed_urls WHERE id = $1",
url_id
)
return {"success": True}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# =============================================================================
# Initialization Endpoint
# =============================================================================
@router.post("/init", response_model=dict)
async def initialize_sources():
"""Initialize default sources from BUNDESLAENDER."""
pool = await get_pool()
if not pool:
raise HTTPException(status_code=503, detail="Database not available")
created = 0
try:
for code, info in BUNDESLAENDER.items():
source_id = generate_id()
success = await upsert_zeugnis_source(
id=source_id,
bundesland=code,
name=info["name"],
license_type=get_license_for_bundesland(code).value,
training_allowed=get_training_allowed(code),
)
if success:
created += 1
return {"success": True, "sources_created": created}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
+105
View File
@@ -0,0 +1,105 @@
"""
Zeugnis Crawler - Start/stop/status control functions.
"""
import asyncio
from typing import Optional, Dict, Any
from .worker import ZeugnisCrawler, get_crawler_state
_crawler_instance: Optional[ZeugnisCrawler] = None
_crawler_task: Optional[asyncio.Task] = None
async def start_crawler(bundesland: Optional[str] = None, source_id: Optional[str] = None) -> bool:
"""Start the crawler."""
global _crawler_instance, _crawler_task
state = get_crawler_state()
if state.is_running:
return False
state.is_running = True
state.documents_crawled_today = 0
state.documents_indexed_today = 0
state.errors_today = 0
_crawler_instance = ZeugnisCrawler()
await _crawler_instance.init()
async def run_crawler():
try:
from metrics_db import get_pool
pool = await get_pool()
if pool:
async with pool.acquire() as conn:
# Get sources to crawl
if source_id:
sources = await conn.fetch(
"SELECT id, bundesland FROM zeugnis_sources WHERE id = $1",
source_id
)
elif bundesland:
sources = await conn.fetch(
"SELECT id, bundesland FROM zeugnis_sources WHERE bundesland = $1",
bundesland
)
else:
sources = await conn.fetch(
"SELECT id, bundesland FROM zeugnis_sources ORDER BY bundesland"
)
for source in sources:
if not state.is_running:
break
await _crawler_instance.crawl_source(source["id"])
except Exception as e:
print(f"Crawler error: {e}")
finally:
state.is_running = False
if _crawler_instance:
await _crawler_instance.close()
_crawler_task = asyncio.create_task(run_crawler())
return True
async def stop_crawler() -> bool:
"""Stop the crawler."""
global _crawler_task
state = get_crawler_state()
if not state.is_running:
return False
state.is_running = False
if _crawler_task:
_crawler_task.cancel()
try:
await _crawler_task
except asyncio.CancelledError:
pass
return True
def get_crawler_status() -> Dict[str, Any]:
"""Get current crawler status."""
state = get_crawler_state()
return {
"is_running": state.is_running,
"current_source": state.current_source_id,
"current_bundesland": state.current_bundesland,
"queue_length": len(state.queue),
"documents_crawled_today": state.documents_crawled_today,
"documents_indexed_today": state.documents_indexed_today,
"errors_today": state.errors_today,
"last_activity": state.last_activity.isoformat() if state.last_activity else None,
}
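# Illustrative sketch (not part of this commit): the start/poll/stop lifecycle
# exactly as the admin endpoints drive it; only the surrounding demo coroutine
# is hypothetical.
async def _crawl_lifecycle_demo() -> None:
    started = await start_crawler(bundesland="ni")  # returns False if already running
    if not started:
        return
    status = get_crawler_status()                   # snapshot dict, keys as above
    print(status["is_running"], status["queue_length"])
    await stop_crawler()                            # cancels the background task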
@@ -0,0 +1,26 @@
"""
Zeugnis Rights-Aware Crawler
Barrel re-export: all public symbols for backward compatibility.
"""
from .text import ( # noqa: F401
extract_text_from_pdf,
extract_text_from_html,
chunk_text,
compute_hash,
)
from .storage import ( # noqa: F401
generate_embeddings,
upload_to_minio,
index_in_qdrant,
)
from .worker import ( # noqa: F401
CrawlerState,
ZeugnisCrawler,
)
from .control import ( # noqa: F401
start_crawler,
stop_crawler,
get_crawler_status,
)
+340
View File
@@ -0,0 +1,340 @@
"""
Zeugnis Rights-Aware Crawler - Data Models
Pydantic models for API requests/responses and internal data structures.
Database schema is defined in metrics_db.py.
"""
from datetime import datetime
from enum import Enum
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
import uuid
# =============================================================================
# Enums
# =============================================================================
class LicenseType(str, Enum):
"""License classification for training permission."""
PUBLIC_DOMAIN = "public_domain" # Amtliche Werke / official works (§5 UrhG)
CC_BY = "cc_by" # Creative Commons Attribution
CC_BY_SA = "cc_by_sa" # CC Attribution-ShareAlike
CC_BY_NC = "cc_by_nc" # CC NonCommercial - NO TRAINING
CC_BY_NC_SA = "cc_by_nc_sa" # CC NC-SA - NO TRAINING
GOV_STATUTE_FREE_USE = "gov_statute" # Government statutes (gemeinfrei / public domain)
ALL_RIGHTS_RESERVED = "all_rights" # Standard copyright - NO TRAINING
UNKNOWN_REQUIRES_REVIEW = "unknown" # Needs manual review
class CrawlStatus(str, Enum):
"""Status of a crawl job or seed URL."""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
PAUSED = "paused"
class DocType(str, Enum):
"""Type of zeugnis document."""
VERORDNUNG = "verordnung" # Official regulation
HANDREICHUNG = "handreichung" # Implementation guide
FORMULAR = "formular" # Form template
ERLASS = "erlass" # Decree
SCHULORDNUNG = "schulordnung" # School regulations
SONSTIGES = "sonstiges" # Other
class EventType(str, Enum):
"""Audit event types."""
CRAWLED = "crawled"
INDEXED = "indexed"
DOWNLOADED = "downloaded"
VIEWED = "viewed"
EXPORTED = "exported"
TRAINED_ON = "trained_on"
DELETED = "deleted"
# =============================================================================
# Bundesland Definitions
# =============================================================================
BUNDESLAENDER = {
"bw": {"name": "Baden-Württemberg", "short": "BW"},
"by": {"name": "Bayern", "short": "BY"},
"be": {"name": "Berlin", "short": "BE"},
"bb": {"name": "Brandenburg", "short": "BB"},
"hb": {"name": "Bremen", "short": "HB"},
"hh": {"name": "Hamburg", "short": "HH"},
"he": {"name": "Hessen", "short": "HE"},
"mv": {"name": "Mecklenburg-Vorpommern", "short": "MV"},
"ni": {"name": "Niedersachsen", "short": "NI"},
"nw": {"name": "Nordrhein-Westfalen", "short": "NW"},
"rp": {"name": "Rheinland-Pfalz", "short": "RP"},
"sl": {"name": "Saarland", "short": "SL"},
"sn": {"name": "Sachsen", "short": "SN"},
"st": {"name": "Sachsen-Anhalt", "short": "ST"},
"sh": {"name": "Schleswig-Holstein", "short": "SH"},
"th": {"name": "Thüringen", "short": "TH"},
}
# Training permission based on the Word document analysis:
# "Amtliches Werk" (official work, §5 UrhG) -> allowed; "Keine Lizenz" (no licence stated) -> not allowed;
# "Eingeschränkt" (restricted) -> not allowed for safety.
TRAINING_PERMISSIONS = {
"bw": True, # Amtliches Werk
"by": True, # Amtliches Werk
"be": False, # Keine Lizenz
"bb": False, # Keine Lizenz
"hb": False, # Eingeschränkt -> False for safety
"hh": False, # Keine Lizenz
"he": True, # Amtliches Werk
"mv": False, # Eingeschränkt -> False for safety
"ni": True, # Amtliches Werk
"nw": True, # Amtliches Werk
"rp": True, # Amtliches Werk
"sl": False, # Keine Lizenz
"sn": True, # Amtliches Werk
"st": False, # Eingeschränkt -> False for safety
"sh": True, # Amtliches Werk
"th": True, # Amtliches Werk
}
# =============================================================================
# API Models - Sources
# =============================================================================
class ZeugnisSourceBase(BaseModel):
"""Base model for zeugnis source."""
bundesland: str = Field(..., description="Bundesland code (e.g., 'ni', 'by')")
name: str = Field(..., description="Full name of the source")
base_url: Optional[str] = Field(None, description="Base URL for the source")
license_type: LicenseType = Field(..., description="License classification")
training_allowed: bool = Field(False, description="Whether AI training is permitted")
class ZeugnisSourceCreate(ZeugnisSourceBase):
"""Model for creating a new source."""
pass
class ZeugnisSource(ZeugnisSourceBase):
"""Full source model with all fields."""
id: str
verified_by: Optional[str] = None
verified_at: Optional[datetime] = None
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class ZeugnisSourceVerify(BaseModel):
"""Model for verifying a source's license."""
verified_by: str = Field(..., description="User ID who verified")
license_type: LicenseType
training_allowed: bool
notes: Optional[str] = None
# =============================================================================
# API Models - Seed URLs
# =============================================================================
class SeedUrlBase(BaseModel):
"""Base model for seed URL."""
url: str = Field(..., description="URL to crawl")
doc_type: DocType = Field(DocType.VERORDNUNG, description="Type of document")
class SeedUrlCreate(SeedUrlBase):
"""Model for creating a new seed URL."""
source_id: str
class SeedUrl(SeedUrlBase):
"""Full seed URL model."""
id: str
source_id: str
status: CrawlStatus = CrawlStatus.PENDING
last_crawled: Optional[datetime] = None
error_message: Optional[str] = None
created_at: datetime
class Config:
from_attributes = True
# =============================================================================
# API Models - Documents
# =============================================================================
class ZeugnisDocumentBase(BaseModel):
"""Base model for zeugnis document."""
title: Optional[str] = None
url: str
content_type: Optional[str] = None
file_size: Optional[int] = None
class ZeugnisDocument(ZeugnisDocumentBase):
"""Full document model."""
id: str
seed_url_id: str
content_hash: Optional[str] = None
minio_path: Optional[str] = None
training_allowed: bool = False
indexed_in_qdrant: bool = False
bundesland: Optional[str] = None
source_name: Optional[str] = None
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class ZeugnisDocumentVersion(BaseModel):
"""Document version for history tracking."""
id: str
document_id: str
version: int
content_hash: str
minio_path: Optional[str] = None
change_summary: Optional[str] = None
created_at: datetime
class Config:
from_attributes = True
# =============================================================================
# API Models - Crawler
# =============================================================================
class CrawlerStatus(BaseModel):
"""Current status of the crawler."""
is_running: bool = False
current_source: Optional[str] = None
current_bundesland: Optional[str] = None
queue_length: int = 0
documents_crawled_today: int = 0
documents_indexed_today: int = 0
last_activity: Optional[datetime] = None
errors_today: int = 0
class CrawlQueueItem(BaseModel):
"""Item in the crawl queue."""
id: str
source_id: str
bundesland: str
source_name: str
priority: int = 5
status: CrawlStatus = CrawlStatus.PENDING
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
documents_found: int = 0
documents_indexed: int = 0
error_count: int = 0
created_at: datetime
class CrawlRequest(BaseModel):
"""Request to start a crawl."""
bundesland: Optional[str] = Field(None, description="Specific Bundesland to crawl")
source_id: Optional[str] = Field(None, description="Specific source ID to crawl")
priority: int = Field(5, ge=1, le=10, description="Priority (1=lowest, 10=highest)")
class CrawlResult(BaseModel):
"""Result of a crawl operation."""
source_id: str
bundesland: str
documents_found: int
documents_indexed: int
documents_skipped: int
errors: List[str]
duration_seconds: float
# =============================================================================
# API Models - Statistics
# =============================================================================
class ZeugnisStats(BaseModel):
"""Statistics for the zeugnis crawler."""
total_sources: int = 0
total_documents: int = 0
indexed_documents: int = 0
training_allowed_documents: int = 0
active_crawls: int = 0
per_bundesland: List[Dict[str, Any]] = []
class BundeslandStats(BaseModel):
"""Statistics per Bundesland."""
bundesland: str
name: str
training_allowed: bool
document_count: int
indexed_count: int
last_crawled: Optional[datetime] = None
# =============================================================================
# API Models - Audit
# =============================================================================
class UsageEvent(BaseModel):
"""Usage event for audit trail."""
id: str
document_id: str
event_type: EventType
user_id: Optional[str] = None
details: Optional[Dict[str, Any]] = None
created_at: datetime
class Config:
from_attributes = True
class AuditExport(BaseModel):
"""GDPR-compliant audit export."""
export_date: datetime
requested_by: str
events: List[UsageEvent]
document_count: int
date_range_start: datetime
date_range_end: datetime
# =============================================================================
# Helper Functions
# =============================================================================
def generate_id() -> str:
"""Generate a new UUID."""
return str(uuid.uuid4())
def get_training_allowed(bundesland: str) -> bool:
"""Get training permission for a Bundesland."""
return TRAINING_PERMISSIONS.get(bundesland.lower(), False)
def get_bundesland_name(code: str) -> str:
"""Get full Bundesland name from code."""
info = BUNDESLAENDER.get(code.lower(), {})
return info.get("name", code)
def get_license_for_bundesland(bundesland: str) -> LicenseType:
"""Get appropriate license type for a Bundesland."""
if TRAINING_PERMISSIONS.get(bundesland.lower(), False):
return LicenseType.GOV_STATUTE_FREE_USE
return LicenseType.UNKNOWN_REQUIRES_REVIEW
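# Illustrative sketch (not part of this commit): how the lookup helpers resolve
# a Bundesland code, following TRAINING_PERMISSIONS and BUNDESLAENDER above.
def _permissions_demo() -> None:
    assert get_training_allowed("ni") is True
    assert get_license_for_bundesland("ni") is LicenseType.GOV_STATUTE_FREE_USE
    assert get_training_allowed("be") is False
    assert get_license_for_bundesland("be") is LicenseType.UNKNOWN_REQUIRES_REVIEW
    assert get_bundesland_name("mv") == "Mecklenburg-Vorpommern"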
@@ -0,0 +1,415 @@
"""
Zeugnis Seed Data - Initial URLs from Word Document
Contains seed URLs for all 16 German federal states (Bundesländer)
based on the "Bundesland URL Zeugnisse.docx" document.
Training permissions:
- "Ja": Amtliches Werk (§5 UrhG, official work) - training allowed
- "Nein": no licence stated - training NOT allowed
- "Eingeschränkt" (restricted): treated as NOT allowed for safety
"""
from typing import Dict, List, Any
# Seed data structure: bundesland code -> source metadata (name, license, training flag, base URL) plus its seed URLs
SEED_DATA: Dict[str, Dict[str, Any]] = {
"bw": {
"name": "Baden-Württemberg",
"license": "gov_statute",
"training_allowed": True,
"base_url": "https://www.landesrecht-bw.de",
"urls": [
{
"url": "https://www.landesrecht-bw.de/jportal/portal/t/cru/page/bsbawueprod.psml?pid=Dokumentanzeige&showdoccase=1&js_peid=Trefferliste&documentnumber=1&numberofresults=1&fromdoctodoc=yes&doc.id=jlr-SchulGBWpP5&doc.part=X&doc.price=0.0&doc.hl=1",
"doc_type": "verordnung",
"title": "Schulgesetz BW - Zeugnisse"
},
{
"url": "https://www.landesrecht-bw.de/jportal/portal/t/cs9/page/bsbawueprod.psml?pid=Dokumentanzeige&showdoccase=1&js_peid=Trefferliste&documentnumber=1&numberofresults=1&fromdoctodoc=yes&doc.id=jlr-NotenBildVBW2016rahmen&doc.part=X&doc.price=0.0",
"doc_type": "verordnung",
"title": "Notenbildungsverordnung"
}
]
},
"by": {
"name": "Bayern",
"license": "gov_statute",
"training_allowed": True,
"base_url": "https://www.gesetze-bayern.de",
"urls": [
{
"url": "https://www.gesetze-bayern.de/Content/Document/BaySchO2016",
"doc_type": "schulordnung",
"title": "Bayerische Schulordnung"
},
{
"url": "https://www.gesetze-bayern.de/Content/Document/BayGSO",
"doc_type": "schulordnung",
"title": "Grundschulordnung Bayern"
},
{
"url": "https://www.gesetze-bayern.de/Content/Document/BayVSO",
"doc_type": "schulordnung",
"title": "Volksschulordnung Bayern"
}
]
},
"be": {
"name": "Berlin",
"license": "unknown",
"training_allowed": False,
"base_url": "https://gesetze.berlin.de",
"urls": [
{
"url": "https://gesetze.berlin.de/bsbe/document/jlr-SchulGBEpP58",
"doc_type": "verordnung",
"title": "Berliner Schulgesetz - Zeugnisse"
},
{
"url": "https://gesetze.berlin.de/bsbe/document/jlr-SekIVBE2010rahmen",
"doc_type": "verordnung",
"title": "Sekundarstufe I-Verordnung"
}
]
},
"bb": {
"name": "Brandenburg",
"license": "unknown",
"training_allowed": False,
"base_url": "https://bravors.brandenburg.de",
"urls": [
{
"url": "https://bravors.brandenburg.de/verordnungen/vvzeugnis",
"doc_type": "verordnung",
"title": "Verwaltungsvorschriften Zeugnisse"
},
{
"url": "https://bravors.brandenburg.de/verordnungen/gostv",
"doc_type": "verordnung",
"title": "GOST-Verordnung Brandenburg"
}
]
},
"hb": {
"name": "Bremen",
"license": "unknown",
"training_allowed": False, # Eingeschränkt -> False for safety
"base_url": "https://www.transparenz.bremen.de",
"urls": [
{
"url": "https://www.transparenz.bremen.de/metainformationen/bremisches-schulgesetz-bremschg-vom-28-juni-2005-121009",
"doc_type": "verordnung",
"title": "Bremisches Schulgesetz"
},
{
"url": "https://www.transparenz.bremen.de/metainformationen/verordnung-ueber-die-sekundarstufe-i-der-oberschule-vom-20-juni-2017-130380",
"doc_type": "verordnung",
"title": "Sekundarstufe I Verordnung Bremen"
}
]
},
"hh": {
"name": "Hamburg",
"license": "unknown",
"training_allowed": False,
"base_url": "https://www.landesrecht-hamburg.de",
"urls": [
{
"url": "https://www.landesrecht-hamburg.de/bsha/document/jlr-SchulGHA2009pP44",
"doc_type": "verordnung",
"title": "Hamburgisches Schulgesetz - Zeugnisse"
},
{
"url": "https://www.landesrecht-hamburg.de/bsha/document/jlr-AusglLeistVHA2011rahmen",
"doc_type": "verordnung",
"title": "Ausbildungs- und Prüfungsordnung"
}
]
},
"he": {
"name": "Hessen",
"license": "gov_statute",
"training_allowed": True,
"base_url": "https://www.rv.hessenrecht.hessen.de",
"urls": [
{
"url": "https://www.rv.hessenrecht.hessen.de/bshe/document/jlr-SchulGHE2017pP73",
"doc_type": "verordnung",
"title": "Hessisches Schulgesetz - Zeugnisse"
},
{
"url": "https://www.rv.hessenrecht.hessen.de/bshe/document/jlr-VOBGM11HE2011rahmen",
"doc_type": "verordnung",
"title": "Verordnung zur Gestaltung des Schulverhältnisses"
}
]
},
"mv": {
"name": "Mecklenburg-Vorpommern",
"license": "unknown",
"training_allowed": False, # Eingeschränkt -> False for safety
"base_url": "https://www.landesrecht-mv.de",
"urls": [
{
"url": "https://www.landesrecht-mv.de/bsmv/document/jlr-SchulGMV2010pP63",
"doc_type": "verordnung",
"title": "Schulgesetz MV - Zeugnisse"
},
{
"url": "https://www.landesrecht-mv.de/bsmv/document/jlr-ZeugnVMVrahmen",
"doc_type": "verordnung",
"title": "Zeugnisverordnung MV"
}
]
},
"ni": {
"name": "Niedersachsen",
"license": "gov_statute",
"training_allowed": True,
"base_url": "https://www.nds-voris.de",
"urls": [
{
"url": "https://www.nds-voris.de/jportal/portal/t/1gxi/page/bsvorisprod.psml?pid=Dokumentanzeige&showdoccase=1&js_peid=Trefferliste&documentnumber=1&numberofresults=1&fromdoctodoc=yes&doc.id=jlr-SchulGNDpP59",
"doc_type": "verordnung",
"title": "Niedersächsisches Schulgesetz - Zeugnisse"
},
{
"url": "https://www.nds-voris.de/jportal/portal/t/1gxi/page/bsvorisprod.psml?pid=Dokumentanzeige&showdoccase=1&js_peid=Trefferliste&documentnumber=1&numberofresults=1&fromdoctodoc=yes&doc.id=jlr-ErgZeugnErlNDrahmen",
"doc_type": "erlass",
"title": "Ergänzende Bestimmungen für Zeugnisse"
},
{
"url": "https://www.mk.niedersachsen.de/startseite/schule/unsere_schulen/allgemein_bildende_schulen/zeugnisse_versetzungen/zeugnisse-und-versetzungen-6351.html",
"doc_type": "handreichung",
"title": "Handreichung Zeugnisse NI"
}
]
},
"nw": {
"name": "Nordrhein-Westfalen",
"license": "gov_statute",
"training_allowed": True,
"base_url": "https://recht.nrw.de",
"urls": [
{
"url": "https://recht.nrw.de/lmi/owa/br_text_anzeigen?v_id=10000000000000000521",
"doc_type": "verordnung",
"title": "Schulgesetz NRW"
},
{
"url": "https://recht.nrw.de/lmi/owa/br_text_anzeigen?v_id=10000000000000000525",
"doc_type": "verordnung",
"title": "Ausbildungs- und Prüfungsordnung Sek I"
},
{
"url": "https://www.schulministerium.nrw/zeugnisse",
"doc_type": "handreichung",
"title": "Handreichung Zeugnisse NRW"
}
]
},
"rp": {
"name": "Rheinland-Pfalz",
"license": "gov_statute",
"training_allowed": True,
"base_url": "https://landesrecht.rlp.de",
"urls": [
{
"url": "https://landesrecht.rlp.de/bsrp/document/jlr-SchulGRPpP61",
"doc_type": "verordnung",
"title": "Schulgesetz RP - Zeugnisse"
},
{
"url": "https://landesrecht.rlp.de/bsrp/document/jlr-ZeugnVRPrahmen",
"doc_type": "verordnung",
"title": "Zeugnisverordnung RP"
}
]
},
"sl": {
"name": "Saarland",
"license": "unknown",
"training_allowed": False,
"base_url": "https://recht.saarland.de",
"urls": [
{
"url": "https://recht.saarland.de/bssl/document/jlr-SchulOGSLrahmen",
"doc_type": "schulordnung",
"title": "Schulordnungsgesetz Saarland"
},
{
"url": "https://recht.saarland.de/bssl/document/jlr-ZeugnVSL2014rahmen",
"doc_type": "verordnung",
"title": "Zeugnisverordnung Saarland"
}
]
},
"sn": {
"name": "Sachsen",
"license": "gov_statute",
"training_allowed": True,
"base_url": "https://www.revosax.sachsen.de",
"urls": [
{
"url": "https://www.revosax.sachsen.de/vorschrift/4192-Schulgesetz-fuer-den-Freistaat-Sachsen",
"doc_type": "verordnung",
"title": "Schulgesetz Sachsen"
},
{
"url": "https://www.revosax.sachsen.de/vorschrift/13500-Schulordnung-Gymnasien-Abiturpruefung",
"doc_type": "schulordnung",
"title": "Schulordnung Gymnasien Sachsen"
}
]
},
"st": {
"name": "Sachsen-Anhalt",
"license": "unknown",
"training_allowed": False, # Eingeschränkt -> False for safety
"base_url": "https://www.landesrecht.sachsen-anhalt.de",
"urls": [
{
"url": "https://www.landesrecht.sachsen-anhalt.de/bsst/document/jlr-SchulGSTpP27",
"doc_type": "verordnung",
"title": "Schulgesetz Sachsen-Anhalt"
},
{
"url": "https://www.landesrecht.sachsen-anhalt.de/bsst/document/jlr-VersetzVST2017rahmen",
"doc_type": "verordnung",
"title": "Versetzungsverordnung ST"
}
]
},
"sh": {
"name": "Schleswig-Holstein",
"license": "gov_statute",
"training_allowed": True,
"base_url": "https://www.gesetze-rechtsprechung.sh.juris.de",
"urls": [
{
"url": "https://www.gesetze-rechtsprechung.sh.juris.de/jportal/portal/t/10wx/page/bsshoprod.psml?pid=Dokumentanzeige&showdoccase=1&js_peid=Trefferliste&documentnumber=1&numberofresults=1&fromdoctodoc=yes&doc.id=jlr-SchulGSHpP22",
"doc_type": "verordnung",
"title": "Schulgesetz SH - Zeugnisse"
},
{
"url": "https://www.gesetze-rechtsprechung.sh.juris.de/jportal/portal/t/10wx/page/bsshoprod.psml?pid=Dokumentanzeige&showdoccase=1&js_peid=Trefferliste&documentnumber=1&numberofresults=1&fromdoctodoc=yes&doc.id=jlr-ZeugnVSHrahmen",
"doc_type": "verordnung",
"title": "Zeugnisverordnung SH"
}
]
},
"th": {
"name": "Thüringen",
"license": "gov_statute",
"training_allowed": True,
"base_url": "https://landesrecht.thueringen.de",
"urls": [
{
"url": "https://landesrecht.thueringen.de/bsth/document/jlr-SchulGTHpP58",
"doc_type": "verordnung",
"title": "Thüringer Schulgesetz - Zeugnisse"
},
{
"url": "https://landesrecht.thueringen.de/bsth/document/jlr-SchulOTH2018rahmen",
"doc_type": "schulordnung",
"title": "Thüringer Schulordnung"
}
]
}
}
async def populate_seed_data():
"""Populate database with seed data."""
from metrics_db import get_pool, upsert_zeugnis_source
from zeugnis_models import generate_id
pool = await get_pool()
if not pool:
print("Database not available")
return False
try:
async with pool.acquire() as conn:
for bundesland, data in SEED_DATA.items():
# Create or update source
source_id = generate_id()
await upsert_zeugnis_source(
id=source_id,
bundesland=bundesland,
name=data["name"],
license_type=data["license"],
training_allowed=data["training_allowed"],
base_url=data.get("base_url"),
)
# Get the actual source ID (might be existing)
existing = await conn.fetchrow(
"SELECT id FROM zeugnis_sources WHERE bundesland = $1",
bundesland
)
if existing:
source_id = existing["id"]
# Add seed URLs
for url_data in data.get("urls", []):
url_id = generate_id()
await conn.execute(
"""
INSERT INTO zeugnis_seed_urls (id, source_id, url, doc_type, status)
VALUES ($1, $2, $3, $4, 'pending')
ON CONFLICT DO NOTHING
""",
url_id, source_id, url_data["url"], url_data["doc_type"]
)
print(f"Populated {bundesland}: {len(data.get('urls', []))} URLs")
print("Seed data population complete!")
return True
except Exception as e:
print(f"Failed to populate seed data: {e}")
return False
def get_training_summary() -> Dict[str, List[str]]:
"""Get summary of training permissions."""
allowed = []
not_allowed = []
for bundesland, data in SEED_DATA.items():
name = data["name"]
if data["training_allowed"]:
allowed.append(f"{name} ({bundesland})")
else:
not_allowed.append(f"{name} ({bundesland})")
return {
"training_allowed": sorted(allowed),
"training_not_allowed": sorted(not_allowed),
"total_allowed": len(allowed),
"total_not_allowed": len(not_allowed),
}
if __name__ == "__main__":
import asyncio
print("=" * 60)
print("Zeugnis Seed Data Summary")
print("=" * 60)
summary = get_training_summary()
print(f"\nTraining ALLOWED ({summary['total_allowed']} Bundesländer):")
for bl in summary["training_allowed"]:
print(f"{bl}")
print(f"\nTraining NOT ALLOWED ({summary['total_not_allowed']} Bundesländer):")
for bl in summary["training_not_allowed"]:
print(f"{bl}")
print("\n" + "=" * 60)
print("To populate database, run:")
print(" python -c 'import asyncio; from zeugnis_seed_data import populate_seed_data; asyncio.run(populate_seed_data())'")
+180
View File
@@ -0,0 +1,180 @@
"""
Zeugnis Crawler - Embedding generation, MinIO upload, and Qdrant indexing.
"""
import io
import os
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any
# =============================================================================
# Configuration
# =============================================================================
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
MINIO_ENDPOINT = os.getenv("MINIO_ENDPOINT", "localhost:9000")
MINIO_ACCESS_KEY = os.getenv("MINIO_ACCESS_KEY", "test-access-key")
MINIO_SECRET_KEY = os.getenv("MINIO_SECRET_KEY", "test-secret-key")
MINIO_BUCKET = os.getenv("MINIO_BUCKET", "breakpilot-rag")
EMBEDDING_BACKEND = os.getenv("EMBEDDING_BACKEND", "local")
ZEUGNIS_COLLECTION = "bp_zeugnis"
# =============================================================================
# Embedding Generation
# =============================================================================
_embedding_model = None
def get_embedding_model():
"""Get or initialize embedding model."""
global _embedding_model
if _embedding_model is None and EMBEDDING_BACKEND == "local":
try:
from sentence_transformers import SentenceTransformer
_embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
print("Loaded local embedding model: all-MiniLM-L6-v2")
except ImportError:
print("Warning: sentence-transformers not installed")
return _embedding_model
async def generate_embeddings(texts: List[str]) -> List[List[float]]:
"""Generate embeddings for a list of texts."""
if not texts:
return []
if EMBEDDING_BACKEND == "local":
model = get_embedding_model()
if model:
embeddings = model.encode(texts, show_progress_bar=False)
return [emb.tolist() for emb in embeddings]
return []
elif EMBEDDING_BACKEND == "openai":
import openai
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
print("Warning: OPENAI_API_KEY not set")
return []
client = openai.AsyncOpenAI(api_key=api_key)
response = await client.embeddings.create(
input=texts,
model="text-embedding-3-small"
)
return [item.embedding for item in response.data]
return []
# =============================================================================
# MinIO Storage
# =============================================================================
async def upload_to_minio(
content: bytes,
bundesland: str,
filename: str,
content_type: str = "application/pdf",
year: Optional[int] = None,
) -> Optional[str]:
"""Upload document to MinIO."""
try:
from minio import Minio
client = Minio(
MINIO_ENDPOINT,
access_key=MINIO_ACCESS_KEY,
secret_key=MINIO_SECRET_KEY,
secure=os.getenv("MINIO_SECURE", "false").lower() == "true"
)
# Ensure bucket exists
if not client.bucket_exists(MINIO_BUCKET):
client.make_bucket(MINIO_BUCKET)
# Build path
year_str = str(year) if year else str(datetime.now().year)
object_name = f"landes-daten/{bundesland}/zeugnis/{year_str}/{filename}"
# Upload
client.put_object(
MINIO_BUCKET,
object_name,
io.BytesIO(content),
len(content),
content_type=content_type,
)
return object_name
except Exception as e:
print(f"MinIO upload failed: {e}")
return None
# =============================================================================
# Qdrant Indexing
# =============================================================================
async def index_in_qdrant(
doc_id: str,
chunks: List[str],
embeddings: List[List[float]],
metadata: Dict[str, Any],
) -> int:
"""Index document chunks in Qdrant."""
try:
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct
client = QdrantClient(url=QDRANT_URL)
# Ensure collection exists
collections = client.get_collections().collections
if not any(c.name == ZEUGNIS_COLLECTION for c in collections):
vector_size = len(embeddings[0]) if embeddings else 384
client.create_collection(
collection_name=ZEUGNIS_COLLECTION,
vectors_config=VectorParams(
size=vector_size,
distance=Distance.COSINE,
),
)
print(f"Created Qdrant collection: {ZEUGNIS_COLLECTION}")
# Create points
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
point_id = str(uuid.uuid4())
points.append(PointStruct(
id=point_id,
vector=embedding,
payload={
"document_id": doc_id,
"chunk_index": i,
"chunk_text": chunk[:500], # Store first 500 chars for preview
"bundesland": metadata.get("bundesland"),
"doc_type": metadata.get("doc_type"),
"title": metadata.get("title"),
"source_url": metadata.get("url"),
"training_allowed": metadata.get("training_allowed", False),
"indexed_at": datetime.now().isoformat(),
}
))
# Upsert
if points:
client.upsert(
collection_name=ZEUGNIS_COLLECTION,
points=points,
)
return len(points)
except Exception as e:
print(f"Qdrant indexing failed: {e}")
return 0
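# Illustrative sketch (not part of this commit): the order in which the crawler
# worker presumably chains these helpers for one document. The variable names
# and the "ni"/"verordnung" metadata are assumptions; the three functions are
# the real ones defined above.
async def _store_and_index_demo(doc_id: str, content: bytes, chunks: List[str]) -> int:
    embeddings = await generate_embeddings(chunks)
    minio_path = await upload_to_minio(content, bundesland="ni", filename=f"{doc_id}.pdf")
    if minio_path is None:
        return 0  # storage failed; skip indexing
    return await index_in_qdrant(
        doc_id, chunks, embeddings,
        metadata={"bundesland": "ni", "doc_type": "verordnung", "training_allowed": True},
    )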
+110
View File
@@ -0,0 +1,110 @@
"""
Zeugnis Crawler - Text extraction, chunking, and hashing utilities.
"""
import hashlib
from typing import List
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 200
def extract_text_from_pdf(content: bytes) -> str:
"""Extract text from PDF bytes."""
try:
from PyPDF2 import PdfReader
import io
reader = PdfReader(io.BytesIO(content))
text_parts = []
for page in reader.pages:
text = page.extract_text()
if text:
text_parts.append(text)
return "\n\n".join(text_parts)
except Exception as e:
print(f"PDF extraction failed: {e}")
return ""
def extract_text_from_html(content: bytes, encoding: str = "utf-8") -> str:
"""Extract text from HTML bytes."""
try:
from bs4 import BeautifulSoup
html = content.decode(encoding, errors="replace")
soup = BeautifulSoup(html, "html.parser")
# Remove script and style elements
for element in soup(["script", "style", "nav", "header", "footer"]):
element.decompose()
# Get text
text = soup.get_text(separator="\n", strip=True)
# Clean up whitespace
lines = [line.strip() for line in text.splitlines() if line.strip()]
return "\n".join(lines)
except Exception as e:
print(f"HTML extraction failed: {e}")
return ""
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
"""Split text into overlapping chunks."""
if not text:
return []
chunks = []
separators = ["\n\n", "\n", ". ", " "]
def split_recursive(text: str, sep_index: int = 0) -> List[str]:
if len(text) <= chunk_size:
return [text] if text.strip() else []
if sep_index >= len(separators):
# Force split at chunk_size
result = []
for i in range(0, len(text), chunk_size - overlap):
chunk = text[i:i + chunk_size]
if chunk.strip():
result.append(chunk)
return result
sep = separators[sep_index]
parts = text.split(sep)
result = []
current = ""
for part in parts:
if len(current) + len(sep) + len(part) <= chunk_size:
current = current + sep + part if current else part
else:
if current.strip():
result.extend(split_recursive(current, sep_index + 1) if len(current) > chunk_size else [current])
current = part
if current.strip():
result.extend(split_recursive(current, sep_index + 1) if len(current) > chunk_size else [current])
return result
chunks = split_recursive(text)
# Add overlap
if overlap > 0 and len(chunks) > 1:
overlapped = []
for i, chunk in enumerate(chunks):
if i > 0:
# Add end of previous chunk
prev_end = chunks[i - 1][-overlap:]
chunk = prev_end + chunk
overlapped.append(chunk)
chunks = overlapped
return chunks
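# Illustrative sketch (not part of this commit): expected chunking behaviour
# with the defaults above (1000-char chunks, 200-char overlap); the sample
# text is made up for demonstration.
def _chunking_demo() -> None:
    text = "Satz. " * 500  # roughly 3000 characters of dummy text
    chunks = chunk_text(text)
    assert all(len(c) <= CHUNK_SIZE + CHUNK_OVERLAP for c in chunks)
    # Every chunk after the first begins with the tail of its predecessor.
    for prev, nxt in zip(chunks, chunks[1:]):
        assert nxt.startswith(prev[-CHUNK_OVERLAP:])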
def compute_hash(content: bytes) -> str:
"""Compute SHA-256 hash of content."""
return hashlib.sha256(content).hexdigest()

Some files were not shown because too many files have changed in this diff.