fix: Restore all files lost during destructive rebase

A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.

This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).

Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Benjamin Admin
2026-02-09 09:51:32 +01:00
parent f7487ee240
commit 21a844cb8a
1986 changed files with 744143 additions and 1731 deletions

View File

@@ -0,0 +1 @@
# BYOEH Test Suite

View File

@@ -0,0 +1,14 @@
"""
Pytest configuration for klausur-service tests.
Ensures local modules (hyde, hybrid_search, rag_evaluation, etc.)
can be imported by adding the backend directory to sys.path.
"""
import sys
from pathlib import Path
# Add the backend directory to sys.path so local modules can be imported
backend_dir = Path(__file__).parent.parent
if str(backend_dir) not in sys.path:
sys.path.insert(0, str(backend_dir))

View File

@@ -0,0 +1,769 @@
"""
Tests for Advanced RAG Features
Tests for the newly implemented RAG quality improvements:
- HyDE (Hypothetical Document Embeddings)
- Hybrid Search (Dense + Sparse/BM25)
- RAG Evaluation (RAGAS-inspired metrics)
- PDF Extraction (Unstructured.io, PyMuPDF, PyPDF2)
- Self-RAG / Corrective RAG
Run with: pytest tests/test_advanced_rag.py -v
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import asyncio
# =============================================================================
# HyDE Tests
# =============================================================================
class TestHyDE:
    """Unit tests for the HyDE (Hypothetical Document Embeddings) module."""

    def test_hyde_config(self):
        """The HyDE info dict exposes the expected configuration keys."""
        from hyde import HYDE_ENABLED, HYDE_LLM_BACKEND, HYDE_MODEL, get_hyde_info
        config = get_hyde_info()
        for key in ("enabled", "llm_backend", "model"):
            assert key in config

    def test_hyde_prompt_template(self):
        """The prompt template carries the query placeholder and domain terms."""
        from hyde import HYDE_PROMPT_TEMPLATE
        for needle in ("{query}", "Erwartungshorizont", "Bildungsstandards"):
            assert needle in HYDE_PROMPT_TEMPLATE

    @pytest.mark.asyncio
    async def test_generate_hypothetical_document_disabled(self):
        """With HyDE switched off, the original query comes back unchanged."""
        from hyde import generate_hypothetical_document
        with patch('hyde.HYDE_ENABLED', False):
            assert await generate_hypothetical_document("Test query") == "Test query"

    @pytest.mark.asyncio
    async def test_hyde_search_fallback(self):
        """hyde_search degrades gracefully when no LLM backend is available."""
        from hyde import hyde_search

        async def fake_search(query, **kwargs):
            return [{"id": "1", "text": "Result for: " + query}]

        with patch('hyde.HYDE_ENABLED', False):
            outcome = await hyde_search(
                query="Test query",
                search_func=fake_search,
            )
        assert outcome["hyde_used"] is False
        # Disabled HyDE must echo the caller's query untouched.
        assert outcome["original_query"] == "Test query"
# =============================================================================
# Hybrid Search Tests
# =============================================================================
class TestHybridSearch:
    """Unit tests for the hybrid (dense + sparse/BM25) search module."""

    def test_hybrid_search_config(self):
        """The info dict exposes weights that form a convex combination."""
        from hybrid_search import HYBRID_ENABLED, DENSE_WEIGHT, SPARSE_WEIGHT, get_hybrid_search_info
        cfg = get_hybrid_search_info()
        assert "enabled" in cfg
        assert "dense_weight" in cfg
        assert "sparse_weight" in cfg
        assert cfg["dense_weight"] + cfg["sparse_weight"] == pytest.approx(1.0)

    def test_bm25_tokenization(self):
        """German tokenization drops stopwords and keeps content words."""
        from hybrid_search import BM25
        tokens = BM25()._tokenize("Der Erwartungshorizont für die Abiturprüfung")
        for stopword in ("der", "für", "die"):
            assert stopword not in tokens
        for content_word in ("erwartungshorizont", "abiturprüfung"):
            assert content_word in tokens

    def test_bm25_stopwords(self):
        """A reasonably large German stopword list is defined."""
        from hybrid_search import GERMAN_STOPWORDS
        for word in ("der", "die", "und", "ist"):
            assert word in GERMAN_STOPWORDS
        assert len(GERMAN_STOPWORDS) > 50

    def test_bm25_fit_and_search(self):
        """Fitting a small corpus lets BM25 rank the matching document first."""
        from hybrid_search import BM25
        corpus = [
            "Der Erwartungshorizont für Mathematik enthält Bewertungskriterien.",
            "Die Gedichtanalyse erfordert formale und inhaltliche Aspekte.",
            "Biologie Klausur Bewertung nach Anforderungsbereichen.",
        ]
        index = BM25()
        index.fit(corpus)
        assert index.N == 3
        assert len(index.corpus) == 3
        hits = index.search("Mathematik Bewertungskriterien", top_k=2)
        assert len(hits) == 2
        assert hits[0][0] == 0  # the maths document must rank highest

    def test_normalize_scores(self):
        """Normalization maps scores onto the closed interval [0, 1]."""
        from hybrid_search import normalize_scores
        raw = [0.5, 1.0, 0.0, 0.75]
        normalized = normalize_scores(raw)
        assert min(normalized) == 0.0
        assert max(normalized) == 1.0
        assert len(normalized) == len(raw)

    def test_normalize_scores_empty(self):
        """An empty score list normalizes to an empty list."""
        from hybrid_search import normalize_scores
        assert normalize_scores([]) == []

    def test_normalize_scores_same_value(self):
        """Constant scores all normalize to 1.0 (no division by zero)."""
        from hybrid_search import normalize_scores
        assert all(value == 1.0 for value in normalize_scores([0.5, 0.5, 0.5]))
# =============================================================================
# RAG Evaluation Tests
# =============================================================================
class TestRAGEvaluation:
    """Unit tests for the RAGAS-inspired RAG evaluation module."""

    def test_evaluation_config(self):
        """The info dict lists the enabled flag and the supported metrics."""
        from rag_evaluation import EVALUATION_ENABLED, EVAL_MODEL, get_evaluation_info
        cfg = get_evaluation_info()
        assert "enabled" in cfg
        assert "metrics" in cfg
        assert "context_precision" in cfg["metrics"]
        assert "faithfulness" in cfg["metrics"]

    def test_text_similarity(self):
        """Jaccard similarity on identical, disjoint and overlapping texts."""
        from rag_evaluation import _text_similarity
        # Identical input -> perfect similarity.
        assert _text_similarity("Hello world", "Hello world") == 1.0
        # No shared tokens -> zero similarity.
        assert _text_similarity("Hello world", "Goodbye universe") == 0.0
        # Partial overlap lands strictly between the extremes.
        partial = _text_similarity("Hello beautiful world", "Hello cruel world")
        assert 0 < partial < 1

    def test_context_precision(self):
        """Precision is the fraction of retrieved documents that are relevant."""
        from rag_evaluation import calculate_context_precision
        score = calculate_context_precision(
            "query",
            ["Doc A about topic X", "Doc B about topic Y"],
            ["Doc A about topic X"],
        )
        assert score == 0.5  # one of two retrieved documents is relevant

    def test_context_precision_empty_retrieved(self):
        """Retrieving nothing yields zero precision."""
        from rag_evaluation import calculate_context_precision
        assert calculate_context_precision("query", [], ["relevant"]) == 0.0

    def test_context_recall(self):
        """Recall is the fraction of relevant documents that were retrieved."""
        from rag_evaluation import calculate_context_recall
        score = calculate_context_recall(
            "query",
            ["Doc A about topic X"],
            ["Doc A about topic X", "Doc B about topic Y"],
        )
        assert score == 0.5  # found one of two relevant documents

    def test_context_recall_no_relevant(self):
        """With no relevant documents there is nothing to miss: recall is 1.0."""
        from rag_evaluation import calculate_context_recall
        assert calculate_context_recall("query", ["something"], []) == 1.0

    def test_load_save_eval_results(self):
        """Evaluation results round-trip through the JSON results file."""
        from rag_evaluation import _load_eval_results, _save_eval_results
        from pathlib import Path
        import tempfile
        import os
        # Create a throwaway file; delete=False so it survives the with-block.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as handle:
            results_path = Path(handle.name)
        try:
            with patch('rag_evaluation.EVAL_RESULTS_FILE', results_path):
                payload = [{"test": "data", "score": 0.8}]
                _save_eval_results(payload)
                assert _load_eval_results() == payload
        finally:
            os.unlink(results_path)
# =============================================================================
# PDF Extraction Tests
# =============================================================================
class TestPDFExtraction:
    """Tests for PDF Extraction module."""

    def test_pdf_extraction_config(self):
        """Test PDF extraction configuration."""
        from pdf_extraction import PDF_BACKEND, get_pdf_extraction_info
        info = get_pdf_extraction_info()
        # The info dict reports the configured backend, what is actually
        # installed, and a recommendation.
        assert "configured_backend" in info
        assert "available_backends" in info
        assert "recommended" in info

    def test_detect_available_backends(self):
        """Test backend detection."""
        from pdf_extraction import _detect_available_backends
        backends = _detect_available_backends()
        assert isinstance(backends, list)
        # In Docker container, at least pypdf (BSD) or unstructured (Apache 2.0) should be available
        # In local test environment without dependencies, list may be empty
        # NOTE: PyMuPDF (AGPL) is NOT installed by default for license compliance
        if backends:
            # If any backend is found, verify it's one of the license-compliant options
            for backend in backends:
                assert backend in ["pypdf", "unstructured", "pymupdf"]

    def test_pdf_extraction_result_class(self):
        """Test PDFExtractionResult data class."""
        from pdf_extraction import PDFExtractionResult
        # Construct a fully-populated result and verify each field is stored.
        result = PDFExtractionResult(
            text="Extracted text",
            backend_used="pypdf",
            pages=5,
            elements=[{"type": "paragraph"}],
            tables=[{"text": "table data"}],
            metadata={"key": "value"},
        )
        assert result.text == "Extracted text"
        assert result.backend_used == "pypdf"
        assert result.pages == 5
        assert len(result.elements) == 1
        assert len(result.tables) == 1
        # Test to_dict: counts are derived from the element/table lists.
        d = result.to_dict()
        assert d["text"] == "Extracted text"
        assert d["element_count"] == 1
        assert d["table_count"] == 1

    def test_pdf_extraction_error(self):
        """Test PDF extraction error handling."""
        from pdf_extraction import PDFExtractionError
        with pytest.raises(PDFExtractionError):
            raise PDFExtractionError("Test error")

    @pytest.mark.xfail(reason="_extract_with_pypdf is internal function not exposed in API")
    def test_pypdf_extraction(self):
        """Test pypdf extraction with a simple PDF (BSD-3-Clause licensed)."""
        from pdf_extraction import _extract_with_pypdf, PDFExtractionError
        # Create a minimal valid PDF
        # This is a very simple PDF that PyPDF2 can parse
        # NOTE: the interior lines of this bytes literal are intentionally
        # unindented -- they are part of the PDF content, not code layout.
        simple_pdf = b"""%PDF-1.4
1 0 obj
<< /Type /Catalog /Pages 2 0 R >>
endobj
2 0 obj
<< /Type /Pages /Kids [3 0 R] /Count 1 >>
endobj
3 0 obj
<< /Type /Page /Parent 2 0 R /MediaBox [0 0 612 792] /Contents 4 0 R >>
endobj
4 0 obj
<< /Length 44 >>
stream
BT
/F1 12 Tf
100 700 Td
(Hello World) Tj
ET
endstream
endobj
xref
0 5
0000000000 65535 f
0000000009 00000 n
0000000058 00000 n
0000000115 00000 n
0000000206 00000 n
trailer
<< /Size 5 /Root 1 0 R >>
startxref
300
%%EOF"""
        # This may fail because the PDF is too minimal, but tests the code path
        try:
            result = _extract_with_pypdf(simple_pdf)
            assert result.backend_used == "pypdf"
        except PDFExtractionError:
            # Expected for minimal PDF
            pass
# =============================================================================
# Self-RAG Tests
# =============================================================================
class TestSelfRAG:
    """Unit tests for the Self-RAG / Corrective RAG module."""

    def test_self_rag_config(self):
        """The info dict exposes thresholds, attempt limit and feature list."""
        from self_rag import (
            SELF_RAG_ENABLED, RELEVANCE_THRESHOLD, GROUNDING_THRESHOLD,
            MAX_RETRIEVAL_ATTEMPTS, get_self_rag_info
        )
        cfg = get_self_rag_info()
        for key in ("enabled", "relevance_threshold", "grounding_threshold",
                    "max_retrieval_attempts", "features"):
            assert key in cfg

    def test_retrieval_decision_enum(self):
        """Each RetrievalDecision member maps to its expected string value."""
        from self_rag import RetrievalDecision
        expected = {
            RetrievalDecision.SUFFICIENT: "sufficient",
            RetrievalDecision.NEEDS_MORE: "needs_more",
            RetrievalDecision.REFORMULATE: "reformulate",
            RetrievalDecision.FALLBACK: "fallback",
        }
        for member, value in expected.items():
            assert member.value == value

    @pytest.mark.asyncio
    async def test_grade_document_relevance_no_llm(self):
        """Without an API key, grading falls back to keyword overlap."""
        from self_rag import grade_document_relevance
        with patch('self_rag.OPENAI_API_KEY', ''):
            score, reason = await grade_document_relevance(
                "Mathematik Bewertungskriterien",
                "Der Erwartungshorizont für Mathematik enthält klare Bewertungskriterien."
            )
        assert 0 <= score <= 1
        assert "Keyword" in reason

    @pytest.mark.asyncio
    async def test_decide_retrieval_strategy_empty_docs(self):
        """An empty result set triggers query reformulation."""
        from self_rag import decide_retrieval_strategy, RetrievalDecision
        decision, meta = await decide_retrieval_strategy("query", [], attempt=1)
        assert decision == RetrievalDecision.REFORMULATE
        assert "No documents" in meta.get("reason", "")

    @pytest.mark.asyncio
    async def test_decide_retrieval_strategy_max_attempts(self):
        """Hitting the attempt ceiling forces the fallback strategy."""
        from self_rag import decide_retrieval_strategy, RetrievalDecision, MAX_RETRIEVAL_ATTEMPTS
        decision, _meta = await decide_retrieval_strategy(
            "query", [], attempt=MAX_RETRIEVAL_ATTEMPTS
        )
        assert decision == RetrievalDecision.FALLBACK

    @pytest.mark.asyncio
    async def test_reformulate_query_no_llm(self):
        """Without an LLM, reformulation at most expands known abbreviations."""
        from self_rag import reformulate_query
        original_query = "EA Mathematik Anforderungen"
        with patch('self_rag.OPENAI_API_KEY', ''):
            rewritten = await reformulate_query(original_query)
        # Either the abbreviation was expanded or the query is unchanged.
        assert "erhöhtes Anforderungsniveau" in rewritten or rewritten == original_query

    @pytest.mark.asyncio
    async def test_filter_relevant_documents(self):
        """Filtering annotates every document with a relevance score."""
        from self_rag import filter_relevant_documents
        candidates = [
            {"text": "Mathematik Bewertungskriterien für Abitur"},
            {"text": "Rezept für Schokoladenkuchen"},
            {"text": "Erwartungshorizont Mathematik eA"},
        ]
        with patch('self_rag.OPENAI_API_KEY', ''):
            relevant, filtered = await filter_relevant_documents(
                "Mathematik Abitur Bewertung",
                candidates,
                threshold=0.1  # low threshold suits keyword-based scoring
            )
        # Every document, kept or dropped, must carry a relevance score.
        for doc in relevant + filtered:
            assert "relevance_score" in doc

    @pytest.mark.asyncio
    async def test_self_rag_retrieve_disabled(self):
        """When disabled, retrieval passes straight through to the search func."""
        from self_rag import self_rag_retrieve

        async def fake_search(query, **kwargs):
            return [{"id": "1", "text": "Result"}]

        with patch('self_rag.SELF_RAG_ENABLED', False):
            outcome = await self_rag_retrieve(
                query="Test query",
                search_func=fake_search,
            )
        assert outcome["self_rag_enabled"] is False
        assert len(outcome["results"]) == 1
# =============================================================================
# Integration Tests - Module Availability
# =============================================================================
class TestModuleAvailability:
    """Smoke tests: every advanced RAG module imports with its public API."""

    def test_hyde_import(self):
        """HyDE exposes its generation and search entry points."""
        from hyde import (
            generate_hypothetical_document,
            hyde_search,
            get_hyde_info,
            HYDE_ENABLED,
        )
        for func in (generate_hypothetical_document, hyde_search):
            assert callable(func)

    def test_hybrid_search_import(self):
        """Hybrid search exposes BM25 and the combined search function."""
        from hybrid_search import (
            BM25,
            hybrid_search,
            get_hybrid_search_info,
            HYBRID_ENABLED,
        )
        assert callable(hybrid_search)
        assert BM25 is not None

    def test_rag_evaluation_import(self):
        """The evaluation module exposes all metric functions."""
        from rag_evaluation import (
            calculate_context_precision,
            calculate_context_recall,
            evaluate_faithfulness,
            evaluate_answer_relevancy,
            evaluate_rag_response,
            get_evaluation_info,
        )
        for func in (calculate_context_precision, evaluate_rag_response):
            assert callable(func)

    def test_pdf_extraction_import(self):
        """PDF extraction exposes both plain and enhanced extractors."""
        from pdf_extraction import (
            extract_text_from_pdf,
            extract_text_from_pdf_enhanced,
            get_pdf_extraction_info,
            PDFExtractionResult,
        )
        for func in (extract_text_from_pdf, extract_text_from_pdf_enhanced):
            assert callable(func)

    def test_self_rag_import(self):
        """Self-RAG exposes retrieval, grading and the decision enum."""
        from self_rag import (
            grade_document_relevance,
            filter_relevant_documents,
            self_rag_retrieve,
            get_self_rag_info,
            RetrievalDecision,
        )
        assert callable(self_rag_retrieve)
        assert RetrievalDecision is not None
# =============================================================================
# End-to-End Feature Verification
# =============================================================================
class TestFeatureVerification:
    """Cross-cutting checks that every advanced-RAG feature is wired up."""

    def test_all_features_have_info_endpoints(self):
        """Each module's info function returns a dict with a status field."""
        from hyde import get_hyde_info
        from hybrid_search import get_hybrid_search_info
        from rag_evaluation import get_evaluation_info
        from pdf_extraction import get_pdf_extraction_info
        from self_rag import get_self_rag_info
        info_functions = (
            get_hyde_info,
            get_hybrid_search_info,
            get_evaluation_info,
            get_pdf_extraction_info,
            get_self_rag_info,
        )
        for info_function in info_functions:
            info = info_function()
            assert isinstance(info, dict)
            # At least one recognizable status field must be present.
            assert any(k in info for k in ["enabled", "configured_backend", "available_backends"])

    def test_environment_variables_documented(self):
        """All documented env vars are readable (defaults apply when unset)."""
        import os
        documented_vars = (
            "HYDE_ENABLED",
            "HYDE_LLM_BACKEND",
            "HYBRID_SEARCH_ENABLED",
            "HYBRID_DENSE_WEIGHT",
            "RAG_EVALUATION_ENABLED",
            "PDF_EXTRACTION_BACKEND",
            "SELF_RAG_ENABLED",
        )
        for var in documented_vars:
            os.getenv(var, "default")  # must not raise
# =============================================================================
# Admin API Tests (RAG Documentation with HTML rendering)
# =============================================================================
class TestRAGAdminAPI:
    """Tests for the RAG admin API endpoints (documentation + system info)."""

    @pytest.mark.xfail(reason="get_rag_documentation not yet implemented - Backlog item")
    @pytest.mark.asyncio
    async def test_rag_documentation_markdown_format(self):
        """The documentation endpoint can serve markdown content."""
        from admin_api import get_rag_documentation
        doc = await get_rag_documentation(format="markdown")
        assert doc["format"] == "markdown"
        assert "content" in doc
        assert doc["status"] in ["success", "inline"]

    @pytest.mark.xfail(reason="get_rag_documentation not yet implemented - Backlog item")
    @pytest.mark.asyncio
    async def test_rag_documentation_html_format(self):
        """The documentation endpoint can serve HTML with styled tables."""
        from admin_api import get_rag_documentation
        doc = await get_rag_documentation(format="html")
        assert doc["format"] == "html"
        assert "content" in doc
        # The HTML payload must include properly styled tables.
        markup = doc["content"]
        assert "<table" in markup
        assert "<th>" in markup or "<td>" in markup
        assert "<style>" in markup
        assert "border-collapse" in markup

    @pytest.mark.xfail(reason="get_rag_system_info not yet implemented - Backlog item")
    @pytest.mark.asyncio
    async def test_rag_system_info_has_feature_status(self):
        """System info carries a per-feature status structure."""
        from admin_api import get_rag_system_info
        info = await get_rag_system_info()
        assert "feature_status" in info.__dict__ or hasattr(info, 'feature_status')

    @pytest.mark.xfail(reason="get_rag_system_info not yet implemented - Backlog item")
    @pytest.mark.asyncio
    async def test_rag_system_info_has_privacy_notes(self):
        """System info carries privacy notes."""
        from admin_api import get_rag_system_info
        info = await get_rag_system_info()
        assert hasattr(info, 'privacy_notes') or "privacy_notes" in str(info)
# =============================================================================
# Reranker Tests
# =============================================================================
class TestReranker:
    """Unit tests for the re-ranker module."""

    def test_reranker_config(self):
        """The info dict exposes backend, model and service availability."""
        from reranker import RERANKER_BACKEND, LOCAL_RERANKER_MODEL, get_reranker_info
        cfg = get_reranker_info()
        assert "backend" in cfg
        assert "model" in cfg
        # The availability flag is named "embedding_service_available", not "available".
        assert "embedding_service_available" in cfg

    def test_reranker_model_license(self):
        """The default model must be Apache-2.0 (bge), never MS MARCO."""
        from reranker import LOCAL_RERANKER_MODEL
        # BAAI/bge-reranker-v2-m3 is Apache 2.0 licensed.
        assert "bge-reranker" in LOCAL_RERANKER_MODEL
        # MS MARCO models carry a non-commercial license and must not appear.
        assert "ms-marco" not in LOCAL_RERANKER_MODEL.lower()

    @pytest.mark.xfail(reason="rerank_results not yet implemented - async wrapper planned")
    @pytest.mark.asyncio
    async def test_rerank_results(self):
        """Re-ranking trims to top_k and annotates entries with scores."""
        from reranker import rerank_results
        candidates = [
            {"text": "Mathematik Bewertungskriterien", "score": 0.5},
            {"text": "Rezept fuer Kuchen", "score": 0.6},
            {"text": "Erwartungshorizont Mathe Abitur", "score": 0.4},
        ]
        ordered = await rerank_results(
            query="Mathematik Abitur Bewertung",
            results=candidates,
            top_k=2,
        )
        # No more than top_k entries come back, each carrying a score field.
        assert len(ordered) <= 2
        for entry in ordered:
            assert "rerank_score" in entry or "score" in entry
# =============================================================================
# API Integration Tests (EH RAG Query with rerank param)
# =============================================================================
class TestEHRAGQueryAPI:
    """Structural tests for the EH RAG query API request/response models."""

    def test_eh_rag_query_params(self):
        """The request model validates with the new rerank flag."""
        from pydantic import BaseModel
        from typing import Optional

        # Mirror of the expected request model shape.
        class RAGQueryRequest(BaseModel):
            query_text: str
            passphrase: str
            subject: Optional[str] = None
            limit: int = 5
            rerank: bool = True  # New param

        # Construction must pass validation without errors.
        request = RAGQueryRequest(
            query_text="Test",
            passphrase="secret",
            rerank=True,
        )
        assert request.rerank is True

    def test_rag_result_has_search_info(self):
        """The response model accepts an optional search_info payload."""
        from pydantic import BaseModel
        from typing import Optional, List, Dict, Any

        # Mirror of the expected response model shape.
        class RAGSource(BaseModel):
            text: str
            score: float
            reranked: Optional[bool] = None

        class RAGResult(BaseModel):
            context: str
            sources: List[RAGSource]
            query: str
            search_info: Optional[Dict[str, Any]] = None

        payload = RAGResult(
            context="Test context",
            sources=[RAGSource(text="Test", score=0.8, reranked=True)],
            query="Test query",
            search_info={
                "rerank_applied": True,
                "hybrid_search_applied": True,
                "total_candidates": 20,
                "embedding_model": "BAAI/bge-m3",
            },
        )
        assert payload.search_info["rerank_applied"] is True
        assert payload.sources[0].reranked is True
# =============================================================================
# Run Tests
# =============================================================================
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,937 @@
"""
Unit Tests for BYOEH (Bring-Your-Own-Expectation-Horizon) Module
Tests cover:
- EH upload and storage
- Key sharing system
- Invitation flow (Invite, Accept, Decline, Revoke)
- Klausur linking
- RAG query functionality
- Audit logging
"""
import pytest
import json
import uuid
import hashlib
from datetime import datetime, timezone, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient
# Import the main app and data structures (post-refactoring modular imports)
import sys
sys.path.insert(0, '..')
from main import app
from storage import (
eh_db,
eh_key_shares_db,
eh_klausur_links_db,
eh_audit_db,
eh_invitations_db,
klausuren_db,
)
from models.eh import (
Erwartungshorizont,
EHKeyShare,
EHKlausurLink,
EHShareInvitation,
)
from models.exam import Klausur
from models.enums import KlausurModus
# =============================================
# FIXTURES
# =============================================
@pytest.fixture
def client():
    """Synchronous test client bound to the FastAPI application."""
    test_client = TestClient(app)
    return test_client
@pytest.fixture
def auth_headers():
    """Authorization header carrying a JWT for the admin teacher account."""
    import jwt
    claims = {
        "user_id": "test-teacher-001",
        "email": "teacher@school.de",
        "role": "admin",
        "tenant_id": "school-001",
    }
    token = jwt.encode(
        claims,
        "your-super-secret-jwt-key-change-in-production",
        algorithm="HS256",
    )
    return {"Authorization": f"Bearer {token}"}
@pytest.fixture
def second_examiner_headers():
    """Authorization header for the second examiner (non-admin teacher)."""
    import jwt
    claims = {
        "user_id": "test-examiner-002",
        "email": "examiner2@school.de",
        # Deliberately a plain teacher, so access-control paths are exercised.
        "role": "teacher",
        "tenant_id": "school-001",
    }
    token = jwt.encode(
        claims,
        "your-super-secret-jwt-key-change-in-production",
        algorithm="HS256",
    )
    return {"Authorization": f"Bearer {token}"}
@pytest.fixture
def sample_eh():
    """Create a sample Erwartungshorizont.

    Registers a fully-indexed EH owned by the admin teacher in the in-memory
    store and returns it; the autouse cleanup fixture clears the store again.
    """
    eh_id = str(uuid.uuid4())
    eh = Erwartungshorizont(
        id=eh_id,
        tenant_id="school-001",
        teacher_id="test-teacher-001",
        title="Deutsch LK Abitur 2025",
        subject="deutsch",
        niveau="eA",
        year=2025,
        aufgaben_nummer="Aufgabe 1",
        encryption_key_hash="abc123" + "0" * 58,  # 64 char hash
        salt="def456" * 5 + "0" * 2,  # 32 char salt
        encrypted_file_path=f"/app/eh-uploads/school-001/{eh_id}/encrypted.bin",
        file_size_bytes=1024000,
        original_filename="erwartungshorizont.pdf",
        rights_confirmed=True,
        rights_confirmed_at=datetime.now(timezone.utc),
        status="indexed",  # already indexed so query paths work immediately
        chunk_count=10,
        indexed_at=datetime.now(timezone.utc),
        error_message=None,
        training_allowed=False,
        created_at=datetime.now(timezone.utc),
        deleted_at=None
    )
    eh_db[eh_id] = eh
    return eh
@pytest.fixture
def sample_klausur():
    """Register a minimal Klausur owned by the test teacher and return it."""
    new_id = str(uuid.uuid4())
    record = Klausur(
        id=new_id,
        title="Deutsch LK Q1",
        subject="deutsch",
        modus=KlausurModus.VORABITUR,
        class_id="class-001",
        year=2025,
        semester="Q1",
        erwartungshorizont=None,  # no EH linked yet
        students=[],
        created_at=datetime.now(timezone.utc),
        teacher_id="test-teacher-001",
    )
    klausuren_db[new_id] = record
    return record
@pytest.fixture(autouse=True)
def cleanup():
    """Wipe every in-memory store once each test has finished."""
    yield
    for store in (
        eh_db,
        eh_key_shares_db,
        eh_klausur_links_db,
        eh_audit_db,
        eh_invitations_db,
        klausuren_db,
    ):
        store.clear()
@pytest.fixture
def sample_invitation(sample_eh):
    """Register a pending second-examiner invitation for sample_eh."""
    new_id = str(uuid.uuid4())
    invitation = EHShareInvitation(
        id=new_id,
        eh_id=sample_eh.id,
        inviter_id="test-teacher-001",
        invitee_id="",  # unresolved until the invitee accepts
        invitee_email="examiner2@school.de",
        role="second_examiner",
        klausur_id=None,
        message="Bitte EH fuer Zweitkorrektur nutzen",
        status="pending",
        expires_at=datetime.now(timezone.utc) + timedelta(days=14),
        created_at=datetime.now(timezone.utc),
        accepted_at=None,
        declined_at=None,
    )
    eh_invitations_db[new_id] = invitation
    return invitation
# =============================================
# EH CRUD TESTS
# =============================================
class TestEHList:
    """Tests for GET /api/v1/eh"""

    def test_list_empty(self, client, auth_headers):
        """With no stored EH the listing is empty."""
        response = client.get("/api/v1/eh", headers=auth_headers)
        assert response.status_code == 200
        assert response.json() == []

    def test_list_own_eh(self, client, auth_headers, sample_eh):
        """Only the caller's own EH entries are returned."""
        response = client.get("/api/v1/eh", headers=auth_headers)
        assert response.status_code == 200
        entries = response.json()
        assert len(entries) == 1
        assert entries[0]["id"] == sample_eh.id
        assert entries[0]["title"] == sample_eh.title

    def test_list_filter_by_subject(self, client, auth_headers, sample_eh):
        """The subject query parameter narrows the listing."""
        matching = client.get("/api/v1/eh?subject=deutsch", headers=auth_headers)
        assert matching.status_code == 200
        assert len(matching.json()) == 1
        non_matching = client.get("/api/v1/eh?subject=englisch", headers=auth_headers)
        assert non_matching.status_code == 200
        assert len(non_matching.json()) == 0

    def test_list_filter_by_year(self, client, auth_headers, sample_eh):
        """The year query parameter narrows the listing."""
        matching = client.get("/api/v1/eh?year=2025", headers=auth_headers)
        assert matching.status_code == 200
        assert len(matching.json()) == 1
        non_matching = client.get("/api/v1/eh?year=2024", headers=auth_headers)
        assert non_matching.status_code == 200
        assert len(non_matching.json()) == 0
class TestEHGet:
    """Tests for GET /api/v1/eh/{id}"""

    def test_get_existing_eh(self, client, auth_headers, sample_eh):
        """Fetching an existing EH returns its details."""
        response = client.get(f"/api/v1/eh/{sample_eh.id}", headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert body["id"] == sample_eh.id
        assert body["title"] == sample_eh.title
        assert body["subject"] == sample_eh.subject

    def test_get_nonexistent_eh(self, client, auth_headers):
        """Fetching an unknown id yields 404."""
        response = client.get(f"/api/v1/eh/{uuid.uuid4()}", headers=auth_headers)
        assert response.status_code == 404
class TestEHDelete:
    """Tests for DELETE /api/v1/eh/{id}"""

    def test_delete_own_eh(self, client, auth_headers, sample_eh):
        """The owner can delete their EH; deletion is soft."""
        response = client.delete(f"/api/v1/eh/{sample_eh.id}", headers=auth_headers)
        assert response.status_code == 200
        assert response.json()["status"] == "deleted"
        # Soft delete: the record stays in the store with deleted_at stamped.
        assert eh_db[sample_eh.id].deleted_at is not None

    def test_delete_others_eh(self, client, second_examiner_headers, sample_eh):
        """A non-owner is rejected with 403."""
        response = client.delete(f"/api/v1/eh/{sample_eh.id}", headers=second_examiner_headers)
        assert response.status_code == 403

    def test_delete_nonexistent_eh(self, client, auth_headers):
        """Deleting an unknown id yields 404."""
        response = client.delete(f"/api/v1/eh/{uuid.uuid4()}", headers=auth_headers)
        assert response.status_code == 404
# =============================================
# KEY SHARING TESTS
# =============================================
class TestEHSharing:
    """Tests for EH key sharing: grant, list, revoke and shared-with-me."""

    @staticmethod
    def _make_share(eh_id, share_id=None):
        """Build an active second-examiner EHKeyShare for *eh_id*.

        Centralizes the share literal that was previously duplicated across
        three tests; field values mirror the original inline constructions.
        A fresh uuid is generated unless *share_id* is given explicitly.
        """
        return EHKeyShare(
            id=share_id if share_id is not None else str(uuid.uuid4()),
            eh_id=eh_id,
            user_id="test-examiner-002",
            encrypted_passphrase="encrypted",
            passphrase_hint="hint",
            granted_by="test-teacher-001",
            granted_at=datetime.now(timezone.utc),
            role="second_examiner",
            klausur_id=None,
            active=True
        )

    def test_share_eh_with_examiner(self, client, auth_headers, sample_eh):
        """Owner can share EH with another examiner."""
        response = client.post(
            f"/api/v1/eh/{sample_eh.id}/share",
            headers=auth_headers,
            json={
                "user_id": "test-examiner-002",
                "role": "second_examiner",
                "encrypted_passphrase": "encrypted-secret-123",
                "passphrase_hint": "Das uebliche Passwort"
            }
        )
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "shared"
        assert data["shared_with"] == "test-examiner-002"
        assert data["role"] == "second_examiner"

    def test_share_invalid_role(self, client, auth_headers, sample_eh):
        """Sharing with an unknown role is rejected with 400."""
        response = client.post(
            f"/api/v1/eh/{sample_eh.id}/share",
            headers=auth_headers,
            json={
                "user_id": "test-examiner-002",
                "role": "invalid_role",
                "encrypted_passphrase": "encrypted-secret-123"
            }
        )
        assert response.status_code == 400

    def test_share_others_eh_fails(self, client, second_examiner_headers, sample_eh):
        """Non-owner cannot share EH (403)."""
        response = client.post(
            f"/api/v1/eh/{sample_eh.id}/share",
            headers=second_examiner_headers,
            json={
                "user_id": "test-examiner-003",
                "role": "third_examiner",
                "encrypted_passphrase": "encrypted-secret-123"
            }
        )
        assert response.status_code == 403

    def test_list_shares(self, client, auth_headers, sample_eh):
        """Owner can list the shares granted on their EH."""
        # Seed one active share for the second examiner.
        eh_key_shares_db[sample_eh.id] = [self._make_share(sample_eh.id)]
        response = client.get(f"/api/v1/eh/{sample_eh.id}/shares", headers=auth_headers)
        assert response.status_code == 200
        data = response.json()
        assert len(data) == 1
        assert data[0]["user_id"] == "test-examiner-002"

    def test_revoke_share(self, client, auth_headers, sample_eh):
        """Owner can revoke a share; the entry stays but becomes inactive."""
        share_id = str(uuid.uuid4())
        eh_key_shares_db[sample_eh.id] = [self._make_share(sample_eh.id, share_id=share_id)]
        response = client.delete(
            f"/api/v1/eh/{sample_eh.id}/shares/{share_id}",
            headers=auth_headers
        )
        assert response.status_code == 200
        assert response.json()["status"] == "revoked"
        # Revocation is soft: the record remains, flagged inactive.
        assert not eh_key_shares_db[sample_eh.id][0].active

    def test_get_shared_with_me(self, client, second_examiner_headers, sample_eh):
        """A grantee sees EH entries that were shared with them."""
        eh_key_shares_db[sample_eh.id] = [self._make_share(sample_eh.id)]
        response = client.get("/api/v1/eh/shared-with-me", headers=second_examiner_headers)
        assert response.status_code == 200
        data = response.json()
        assert len(data) == 1
        assert data[0]["eh"]["id"] == sample_eh.id
# =============================================
# KLAUSUR LINKING TESTS
# =============================================
class TestEHKlausurLinking:
    """Tests for linking and unlinking EHs to Klausuren."""

    @staticmethod
    def _install_link(eh_id, klausur_id):
        """Put a single EH<->Klausur link into the in-memory store."""
        link = EHKlausurLink(
            id=str(uuid.uuid4()),
            eh_id=eh_id,
            klausur_id=klausur_id,
            linked_by="test-teacher-001",
            linked_at=datetime.now(timezone.utc),
        )
        eh_klausur_links_db[klausur_id] = [link]
        return link

    def test_link_eh_to_klausur(self, client, auth_headers, sample_eh, sample_klausur):
        """The owner can link an EH to a Klausur."""
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/link-klausur",
            headers=auth_headers,
            json={"klausur_id": sample_klausur.id},
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "linked"
        assert body["eh_id"] == sample_eh.id
        assert body["klausur_id"] == sample_klausur.id

    def test_get_linked_eh(self, client, auth_headers, sample_eh, sample_klausur):
        """Linked EHs are returned for a Klausur, flagged with ownership."""
        self._install_link(sample_eh.id, sample_klausur.id)
        resp = client.get(
            f"/api/v1/klausuren/{sample_klausur.id}/linked-eh",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        body = resp.json()
        assert len(body) == 1
        assert body[0]["eh"]["id"] == sample_eh.id
        assert body[0]["is_owner"] is True

    def test_unlink_eh_from_klausur(self, client, auth_headers, sample_eh, sample_klausur):
        """The owner can remove an existing EH<->Klausur link."""
        self._install_link(sample_eh.id, sample_klausur.id)
        resp = client.delete(
            f"/api/v1/eh/{sample_eh.id}/link-klausur/{sample_klausur.id}",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        assert resp.json()["status"] == "unlinked"
# =============================================
# AUDIT LOG TESTS
# =============================================
class TestAuditLog:
    """Tests for audit logging of EH actions."""

    @staticmethod
    def _share_once(client, headers, eh_id):
        """Trigger one share action so an audit entry gets written."""
        return client.post(
            f"/api/v1/eh/{eh_id}/share",
            headers=headers,
            json={
                "user_id": "test-examiner-002",
                "role": "second_examiner",
                "encrypted_passphrase": "encrypted-secret-123"
            },
        )

    def test_audit_log_on_share(self, client, auth_headers, sample_eh):
        """Sharing appends a 'share' entry to the audit log."""
        before = len(eh_audit_db)
        self._share_once(client, auth_headers, sample_eh.id)
        assert len(eh_audit_db) > before
        newest = eh_audit_db[-1]
        assert newest.action == "share"
        assert newest.eh_id == sample_eh.id

    def test_get_audit_log(self, client, auth_headers, sample_eh):
        """The audit-log endpoint returns recorded entries for an EH."""
        self._share_once(client, auth_headers, sample_eh.id)
        resp = client.get(
            f"/api/v1/eh/audit-log?eh_id={sample_eh.id}",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        assert len(resp.json()) > 0
# =============================================
# RIGHTS TEXT TESTS
# =============================================
class TestRightsText:
    """Tests for the rights-confirmation text endpoint."""

    def test_get_rights_text(self, client, auth_headers):
        """The endpoint returns a versioned text that mentions copyright."""
        resp = client.get("/api/v1/eh/rights-text", headers=auth_headers)
        assert resp.status_code == 200
        body = resp.json()
        assert "text" in body
        assert "version" in body
        assert "Urheberrecht" in body["text"]
# =============================================
# ENCRYPTION SERVICE TESTS
# =============================================
class TestEncryptionUtils:
    """Tests for encryption utilities (eh_pipeline.py)."""

    def test_hash_key(self):
        """Hashing the same passphrase with the same salt is deterministic."""
        from eh_pipeline import hash_key
        import os
        passphrase = "test-secret-passphrase"
        salt_hex = os.urandom(16).hex()
        hash1 = hash_key(passphrase, salt_hex)
        hash2 = hash_key(passphrase, salt_hex)
        assert hash1 == hash2
        assert len(hash1) == 64  # SHA-256 hex digest

    def test_verify_key_hash(self):
        """Verification accepts the right passphrase and rejects a wrong one."""
        from eh_pipeline import hash_key, verify_key_hash
        import os
        passphrase = "test-secret-passphrase"
        salt_hex = os.urandom(16).hex()
        key_hash = hash_key(passphrase, salt_hex)
        assert verify_key_hash(passphrase, salt_hex, key_hash) is True
        assert verify_key_hash("wrong-passphrase", salt_hex, key_hash) is False

    def test_chunk_text(self):
        """Chunking yields windows whose tail overlaps the next chunk's head.

        Uses non-repeating content: the previous all-'A' input made any two
        200-char slices compare equal, so the overlap assertion could never
        fail even for a broken implementation.
        """
        from eh_pipeline import chunk_text
        text = "".join(chr(ord("A") + (i % 26)) for i in range(2000))
        chunks = chunk_text(text, chunk_size=1000, overlap=200)
        assert len(chunks) >= 2
        # Last 200 chars of one chunk must equal the first 200 of the next.
        assert chunks[0][-200:] == chunks[1][:200]

    def test_encrypt_decrypt_text(self):
        """Encryption round-trips and the ciphertext differs from the plaintext."""
        from eh_pipeline import encrypt_text, decrypt_text
        plaintext = "Dies ist ein geheimer Text."
        passphrase = "geheim123"
        salt = "a" * 32  # 32 hex chars = 16 bytes
        encrypted = encrypt_text(plaintext, passphrase, salt)
        # Ciphertext must not leak the plaintext verbatim.
        assert encrypted != plaintext
        decrypted = decrypt_text(encrypted, passphrase, salt)
        assert decrypted == plaintext
# =============================================
# INVITATION FLOW TESTS
# =============================================
class TestEHInvitationFlow:
    """Tests for the invite/accept/decline/revoke invitation workflow."""

    def test_invite_to_eh(self, client, auth_headers, sample_eh):
        """The owner can send a share invitation for an EH."""
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/invite",
            headers=auth_headers,
            json={
                "invitee_email": "zweitkorrektor@school.de",
                "role": "second_examiner",
                "message": "Bitte EH fuer Zweitkorrektur nutzen",
                "expires_in_days": 14
            },
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "invited"
        assert body["invitee_email"] == "zweitkorrektor@school.de"
        assert body["role"] == "second_examiner"
        assert "invitation_id" in body
        assert "expires_at" in body

    def test_invite_invalid_role(self, client, auth_headers, sample_eh):
        """Inviting with an unknown role fails with HTTP 400."""
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/invite",
            headers=auth_headers,
            json={
                "invitee_email": "zweitkorrektor@school.de",
                "role": "invalid_role"
            },
        )
        assert resp.status_code == 400

    def test_invite_by_non_owner_fails(self, client, second_examiner_headers, sample_eh):
        """Only the owner may send invitations."""
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/invite",
            headers=second_examiner_headers,
            json={
                "invitee_email": "drittkorrektor@school.de",
                "role": "third_examiner"
            },
        )
        assert resp.status_code == 403

    def test_duplicate_pending_invitation_fails(self, client, auth_headers, sample_eh, sample_invitation):
        """A second pending invitation to the same address is rejected."""
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/invite",
            headers=auth_headers,
            json={
                "invitee_email": "examiner2@school.de",  # same address as sample_invitation
                "role": "second_examiner"
            },
        )
        assert resp.status_code == 409  # Conflict

    def test_list_pending_invitations(self, client, second_examiner_headers, sample_eh, sample_invitation):
        """An invitee sees invitations addressed to them."""
        resp = client.get("/api/v1/eh/invitations/pending", headers=second_examiner_headers)
        assert resp.status_code == 200
        body = resp.json()
        assert len(body) == 1
        assert body[0]["invitation"]["id"] == sample_invitation.id
        assert body[0]["eh"]["title"] == sample_eh.title

    def test_list_sent_invitations(self, client, auth_headers, sample_eh, sample_invitation):
        """The inviter sees the invitations they have sent."""
        resp = client.get("/api/v1/eh/invitations/sent", headers=auth_headers)
        assert resp.status_code == 200
        body = resp.json()
        assert len(body) == 1
        assert body[0]["invitation"]["id"] == sample_invitation.id

    def test_accept_invitation(self, client, second_examiner_headers, sample_eh, sample_invitation):
        """Accepting creates a key share and flips the invitation state."""
        resp = client.post(
            f"/api/v1/eh/invitations/{sample_invitation.id}/accept",
            headers=second_examiner_headers,
            json={"encrypted_passphrase": "encrypted-secret-key-for-zk"},
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "accepted"
        assert body["eh_id"] == sample_eh.id
        assert "share_id" in body
        # Invitation state was updated ...
        assert eh_invitations_db[sample_invitation.id].status == "accepted"
        # ... and exactly one key share was created for the EH.
        assert sample_eh.id in eh_key_shares_db
        assert len(eh_key_shares_db[sample_eh.id]) == 1

    def test_accept_invitation_wrong_user(self, client, auth_headers, sample_eh, sample_invitation):
        """Someone other than the invitee cannot accept (teacher headers here)."""
        resp = client.post(
            f"/api/v1/eh/invitations/{sample_invitation.id}/accept",
            headers=auth_headers,
            json={"encrypted_passphrase": "encrypted-secret"},
        )
        assert resp.status_code == 403

    def test_accept_expired_invitation(self, client, second_examiner_headers, sample_eh):
        """An invitation past its expiry date cannot be accepted."""
        now = datetime.now(timezone.utc)
        invitation_id = str(uuid.uuid4())
        eh_invitations_db[invitation_id] = EHShareInvitation(
            id=invitation_id,
            eh_id=sample_eh.id,
            inviter_id="test-teacher-001",
            invitee_id="",
            invitee_email="examiner2@school.de",
            role="second_examiner",
            klausur_id=None,
            message=None,
            status="pending",
            expires_at=now - timedelta(days=1),   # already expired
            created_at=now - timedelta(days=15),
            accepted_at=None,
            declined_at=None,
        )
        resp = client.post(
            f"/api/v1/eh/invitations/{invitation_id}/accept",
            headers=second_examiner_headers,
            json={"encrypted_passphrase": "encrypted-secret"},
        )
        assert resp.status_code == 400
        assert "expired" in resp.json()["detail"].lower()

    def test_decline_invitation(self, client, second_examiner_headers, sample_invitation):
        """The invitee can decline an invitation."""
        resp = client.post(
            f"/api/v1/eh/invitations/{sample_invitation.id}/decline",
            headers=second_examiner_headers,
        )
        assert resp.status_code == 200
        assert resp.json()["status"] == "declined"
        assert eh_invitations_db[sample_invitation.id].status == "declined"

    def test_decline_invitation_wrong_user(self, client, auth_headers, sample_invitation):
        """Only the invitee may decline."""
        resp = client.post(
            f"/api/v1/eh/invitations/{sample_invitation.id}/decline",
            headers=auth_headers,
        )
        assert resp.status_code == 403

    def test_revoke_invitation(self, client, auth_headers, sample_invitation):
        """The inviter can revoke a still-pending invitation."""
        resp = client.delete(
            f"/api/v1/eh/invitations/{sample_invitation.id}",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        assert resp.json()["status"] == "revoked"
        assert eh_invitations_db[sample_invitation.id].status == "revoked"

    def test_revoke_invitation_wrong_user(self, client, second_examiner_headers, sample_invitation):
        """Somebody other than the inviter cannot revoke."""
        resp = client.delete(
            f"/api/v1/eh/invitations/{sample_invitation.id}",
            headers=second_examiner_headers,
        )
        assert resp.status_code == 403

    def test_revoke_non_pending_invitation(self, client, auth_headers, sample_eh):
        """An already accepted invitation can no longer be revoked."""
        now = datetime.now(timezone.utc)
        invitation_id = str(uuid.uuid4())
        eh_invitations_db[invitation_id] = EHShareInvitation(
            id=invitation_id,
            eh_id=sample_eh.id,
            inviter_id="test-teacher-001",
            invitee_id="test-examiner-002",
            invitee_email="examiner2@school.de",
            role="second_examiner",
            klausur_id=None,
            message=None,
            status="accepted",
            expires_at=now + timedelta(days=14),
            created_at=now,
            accepted_at=now,
            declined_at=None,
        )
        resp = client.delete(
            f"/api/v1/eh/invitations/{invitation_id}",
            headers=auth_headers,
        )
        assert resp.status_code == 400

    def test_get_access_chain(self, client, auth_headers, sample_eh, sample_invitation):
        """The owner sees owner, active shares and pending invitations."""
        eh_key_shares_db[sample_eh.id] = [
            EHKeyShare(
                id=str(uuid.uuid4()),
                eh_id=sample_eh.id,
                user_id="test-examiner-003",
                encrypted_passphrase="encrypted",
                passphrase_hint="",
                granted_by="test-teacher-001",
                granted_at=datetime.now(timezone.utc),
                role="third_examiner",
                klausur_id=None,
                active=True,
            )
        ]
        resp = client.get(f"/api/v1/eh/{sample_eh.id}/access-chain", headers=auth_headers)
        assert resp.status_code == 200
        body = resp.json()
        assert body["eh_id"] == sample_eh.id
        assert body["owner"]["user_id"] == "test-teacher-001"
        assert len(body["active_shares"]) == 1
        assert len(body["pending_invitations"]) == 1  # sample_invitation
class TestInvitationWorkflow:
    """End-to-end tests for the complete invitation workflow."""

    def test_complete_invite_accept_workflow(
        self, client, auth_headers, second_examiner_headers, sample_eh, sample_klausur
    ):
        """Full happy path: invite -> pending -> accept -> shared access."""
        # Step 1: the owner invites the second examiner.
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/invite",
            headers=auth_headers,
            json={
                "invitee_email": "examiner2@school.de",
                "role": "second_examiner",
                "klausur_id": sample_klausur.id,
                "message": "Bitte fuer Zweitkorrektur nutzen"
            },
        )
        assert resp.status_code == 200
        invitation_id = resp.json()["invitation_id"]

        # Step 2: the invitee finds the pending invitation.
        resp = client.get("/api/v1/eh/invitations/pending", headers=second_examiner_headers)
        assert resp.status_code == 200
        pending = resp.json()
        assert len(pending) == 1
        assert pending[0]["invitation"]["id"] == invitation_id

        # Step 3: the invitee accepts.
        resp = client.post(
            f"/api/v1/eh/invitations/{invitation_id}/accept",
            headers=second_examiner_headers,
            json={"encrypted_passphrase": "encrypted-key-for-zk"},
        )
        assert resp.status_code == 200

        # Step 4: the EH now shows up in the invitee's shared list.
        resp = client.get("/api/v1/eh/shared-with-me", headers=second_examiner_headers)
        assert resp.status_code == 200
        shared = resp.json()
        assert len(shared) == 1
        assert shared[0]["eh"]["id"] == sample_eh.id

        # Step 5: the inviter sees the invitation marked as accepted.
        resp = client.get("/api/v1/eh/invitations/sent", headers=auth_headers)
        assert resp.status_code == 200
        sent = resp.json()
        assert len(sent) == 1
        assert sent[0]["invitation"]["status"] == "accepted"

    def test_invite_decline_reinvite_workflow(
        self, client, auth_headers, second_examiner_headers, sample_eh
    ):
        """A declined invitation does not block a fresh invitation."""
        # Step 1: the owner invites.
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/invite",
            headers=auth_headers,
            json={
                "invitee_email": "examiner2@school.de",
                "role": "second_examiner"
            },
        )
        assert resp.status_code == 200
        invitation_id = resp.json()["invitation_id"]

        # Step 2: the invitee declines.
        resp = client.post(
            f"/api/v1/eh/invitations/{invitation_id}/decline",
            headers=second_examiner_headers,
        )
        assert resp.status_code == 200

        # Step 3: the owner may invite the same address again.
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/invite",
            headers=auth_headers,
            json={
                "invitee_email": "examiner2@school.de",
                "role": "second_examiner",
                "message": "Zweiter Versuch"
            },
        )
        assert resp.status_code == 200  # new invitation allowed
# =============================================
# INTEGRATION TESTS
# =============================================
class TestEHWorkflow:
    """Integration tests covering the direct-sharing workflow."""

    def test_complete_sharing_workflow(
        self, client, auth_headers, second_examiner_headers, sample_eh, sample_klausur
    ):
        """Upload -> link -> share -> second examiner gains access."""
        # Step 1: link the EH to the Klausur.
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/link-klausur",
            headers=auth_headers,
            json={"klausur_id": sample_klausur.id},
        )
        assert resp.status_code == 200

        # Step 2: share it with the second examiner.
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/share",
            headers=auth_headers,
            json={
                "user_id": "test-examiner-002",
                "role": "second_examiner",
                "encrypted_passphrase": "encrypted-secret",
                "klausur_id": sample_klausur.id
            },
        )
        assert resp.status_code == 200

        # Step 3: the second examiner sees it under shared-with-me.
        resp = client.get("/api/v1/eh/shared-with-me", headers=second_examiner_headers)
        assert resp.status_code == 200
        shared = resp.json()
        assert len(shared) == 1
        assert shared[0]["eh"]["id"] == sample_eh.id

        # Step 4: the linked-EH view works for the non-owner as well.
        resp = client.get(
            f"/api/v1/klausuren/{sample_klausur.id}/linked-eh",
            headers=second_examiner_headers,
        )
        assert resp.status_code == 200
        linked = resp.json()
        assert len(linked) == 1
        assert linked[0]["is_owner"] is False
        assert linked[0]["share"] is not None
if __name__ == "__main__":
    # Allow running this test module directly without invoking pytest by hand.
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,623 @@
"""
Tests for Legal Templates RAG System.
Tests template_sources.py, github_crawler.py, legal_templates_ingestion.py,
and the admin API endpoints for legal templates.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime
import json
# =============================================================================
# Template Sources Tests
# =============================================================================
class TestLicenseType:
    """Tests for the LicenseType enum."""

    def test_license_types_exist(self):
        """All expected license identifiers map to their string values."""
        from template_sources import LicenseType

        expected = {
            "PUBLIC_DOMAIN": "public_domain",
            "CC0": "cc0",
            "UNLICENSE": "unlicense",
            "MIT": "mit",
            "CC_BY_4": "cc_by_4",
            "REUSE_NOTICE": "reuse_notice",
        }
        for member, value in expected.items():
            assert getattr(LicenseType, member).value == value
class TestLicenseInfo:
    """Tests for the LicenseInfo dataclass."""

    def test_license_info_creation(self):
        """A LicenseInfo carries its defaults for training/output permissions."""
        from template_sources import LicenseInfo, LicenseType

        info = LicenseInfo(
            id=LicenseType.CC0,
            name="CC0 1.0 Universal",
            url="https://creativecommons.org/publicdomain/zero/1.0/",
            attribution_required=False,
        )
        assert info.id == LicenseType.CC0
        assert info.attribution_required is False
        assert info.training_allowed is True
        assert info.output_allowed is True

    def test_get_attribution_text_no_attribution(self):
        """No attribution required -> empty attribution string."""
        from template_sources import LicenseInfo, LicenseType

        info = LicenseInfo(
            id=LicenseType.CC0,
            name="CC0",
            url="https://example.com",
            attribution_required=False,
        )
        assert info.get_attribution_text("Test Source", "https://test.com") == ""

    def test_get_attribution_text_with_template(self):
        """The attribution template is filled with source name and URL."""
        from template_sources import LicenseInfo, LicenseType

        info = LicenseInfo(
            id=LicenseType.MIT,
            name="MIT License",
            url="https://opensource.org/licenses/MIT",
            attribution_required=True,
            attribution_template="Based on [{source_name}]({source_url}) - MIT License",
        )
        text = info.get_attribution_text("Test Source", "https://test.com")
        assert "Test Source" in text
        assert "https://test.com" in text
class TestSourceConfig:
    """Tests for the SourceConfig dataclass."""

    def test_source_config_creation(self):
        """A SourceConfig stores its fields and is enabled by default."""
        from template_sources import SourceConfig, LicenseType

        cfg = SourceConfig(
            name="test-source",
            license_type=LicenseType.CC0,
            template_types=["privacy_policy", "terms_of_service"],
            languages=["de", "en"],
            jurisdiction="DE",
            description="Test description",
            repo_url="https://github.com/test/repo",
        )
        assert cfg.name == "test-source"
        assert cfg.license_type == LicenseType.CC0
        assert "privacy_policy" in cfg.template_types
        assert cfg.enabled is True

    def test_source_config_license_info(self):
        """license_info resolves the full LicenseInfo for the source's license."""
        from template_sources import SourceConfig, LicenseType, LICENSES

        cfg = SourceConfig(
            name="test-source",
            license_type=LicenseType.MIT,
            template_types=["privacy_policy"],
            languages=["en"],
            jurisdiction="US",
            description="Test",
        )
        info = cfg.license_info
        assert info.id == LicenseType.MIT
        assert info.attribution_required is True
class TestTemplateSources:
    """Tests for the TEMPLATE_SOURCES registry and its helper functions."""

    def test_template_sources_not_empty(self):
        """At least one template source is registered."""
        from template_sources import TEMPLATE_SOURCES
        assert len(TEMPLATE_SOURCES) > 0

    def test_github_site_policy_exists(self):
        """The github-site-policy source is registered with its repo URL."""
        from template_sources import TEMPLATE_SOURCES
        matches = [s for s in TEMPLATE_SOURCES if s.name == "github-site-policy"]
        assert matches
        assert matches[0].repo_url == "https://github.com/github/site-policy"

    def test_enabled_sources(self):
        """get_enabled_sources returns only enabled entries."""
        from template_sources import get_enabled_sources
        assert all(s.enabled for s in get_enabled_sources())

    def test_sources_by_priority(self):
        """get_sources_by_priority filters by maximum priority."""
        from template_sources import get_sources_by_priority
        # Priority 1 only, then priorities 1-2.
        assert all(s.priority == 1 for s in get_sources_by_priority(1))
        assert all(s.priority <= 2 for s in get_sources_by_priority(2))

    def test_sources_by_license(self):
        """get_sources_by_license filters by the given license type."""
        from template_sources import get_sources_by_license, LicenseType
        cc0 = get_sources_by_license(LicenseType.CC0)
        assert all(s.license_type == LicenseType.CC0 for s in cc0)
# =============================================================================
# GitHub Crawler Tests
# =============================================================================
class TestMarkdownParser:
    """Tests for the MarkdownParser class."""

    def test_parse_simple_markdown(self):
        """Title, file type, and body text are extracted from markdown."""
        from github_crawler import MarkdownParser

        content = """# Test Title
This is some content.
## Section 1
More content here.
"""
        doc = MarkdownParser.parse(content, "test.md")
        assert doc.title == "Test Title"
        assert doc.file_type == "markdown"
        assert "content" in doc.text

    def test_extract_title_from_heading(self):
        """An h1 heading becomes the title."""
        from github_crawler import MarkdownParser
        assert MarkdownParser._extract_title("# My Document\n\nContent", "fallback.md") == "My Document"

    def test_extract_title_fallback(self):
        """Without a heading, the title is derived from the file name."""
        from github_crawler import MarkdownParser
        assert MarkdownParser._extract_title("No heading here", "my-document.md") == "My Document"

    def test_detect_german_language(self):
        """German prose is detected as 'de'."""
        from github_crawler import MarkdownParser
        text = "Dies ist eine Datenschutzerklaerung fuer die Verarbeitung personenbezogener Daten."
        assert MarkdownParser._detect_language(text) == "de"

    def test_detect_english_language(self):
        """English prose is detected as 'en'."""
        from github_crawler import MarkdownParser
        text = "This is a privacy policy for processing personal data in our application."
        assert MarkdownParser._detect_language(text) == "en"

    def test_find_placeholders(self):
        """All three placeholder syntaxes are found."""
        from github_crawler import MarkdownParser
        found = MarkdownParser._find_placeholders(
            "Company: [COMPANY_NAME], Contact: {email}, Address: __ADDRESS__"
        )
        assert "[COMPANY_NAME]" in found
        assert "{email}" in found
        assert "__ADDRESS__" in found
class TestHTMLParser:
    """Tests for the HTMLParser class."""

    def test_parse_simple_html(self):
        """Title and visible text are extracted from an HTML page."""
        from github_crawler import HTMLParser

        content = """<!DOCTYPE html>
<html>
<head><title>Test Page</title></head>
<body>
<h1>Welcome</h1>
<p>This is content.</p>
</body>
</html>"""
        doc = HTMLParser.parse(content, "test.html")
        assert doc.title == "Test Page"
        assert doc.file_type == "html"
        assert "Welcome" in doc.text
        assert "content" in doc.text

    def test_html_to_text_removes_scripts(self):
        """Script bodies are stripped while regular text survives."""
        from github_crawler import HTMLParser
        text = HTMLParser._html_to_text("<p>Text</p><script>alert('bad');</script><p>More</p>")
        assert "alert" not in text
        assert "Text" in text
        assert "More" in text
class TestJSONParser:
    """Tests for the JSONParser class."""

    def test_parse_simple_json(self):
        """A flat JSON document yields a single parsed document."""
        from github_crawler import JSONParser

        payload = {
            "title": "Privacy Policy",
            "text": "This is the privacy policy content.",
            "language": "en",
        }
        docs = JSONParser.parse(json.dumps(payload), "policy.json")
        assert len(docs) == 1
        assert docs[0].title == "Privacy Policy"
        assert "privacy policy content" in docs[0].text

    def test_parse_nested_json(self):
        """Nested sections are extracted as separate documents."""
        from github_crawler import JSONParser

        payload = {
            "sections": {
                "intro": {"title": "Introduction", "text": "Welcome text"},
                "data": {"title": "Data Collection", "text": "Collection info"},
            }
        }
        docs = JSONParser.parse(json.dumps(payload), "nested.json")
        # Nested documents should be flattened out.
        assert len(docs) >= 2
class TestExtractedDocument:
    """Tests for the ExtractedDocument dataclass."""

    def test_extracted_document_hash(self):
        """A SHA-256 source hash is derived automatically on construction."""
        from github_crawler import ExtractedDocument

        doc = ExtractedDocument(
            text="Some content",
            title="Test",
            file_path="test.md",
            file_type="markdown",
            source_url="https://example.com",
        )
        assert doc.source_hash != ""
        assert len(doc.source_hash) == 64  # hex-encoded SHA-256
# =============================================================================
# Legal Templates Ingestion Tests
# =============================================================================
class TestLegalTemplatesIngestion:
    """Tests for the LegalTemplatesIngestion class."""

    @pytest.fixture
    def mock_qdrant(self):
        """Patch the Qdrant client used by the ingestion module."""
        with patch('legal_templates_ingestion.QdrantClient') as patched:
            fake = MagicMock()
            fake.get_collections.return_value.collections = []
            patched.return_value = fake
            yield fake

    @pytest.fixture
    def mock_http_client(self):
        """Patch the HTTP client used for embedding requests."""
        with patch('legal_templates_ingestion.httpx.AsyncClient') as patched:
            fake = AsyncMock()
            patched.return_value = fake
            yield fake

    def test_chunk_text_short(self):
        """Text shorter than the chunk size stays a single chunk."""
        from legal_templates_ingestion import LegalTemplatesIngestion
        with patch('legal_templates_ingestion.QdrantClient'):
            ingestion = LegalTemplatesIngestion()
            pieces = ingestion._chunk_text("Short text", chunk_size=1000)
            assert len(pieces) == 1
            assert pieces[0] == "Short text"

    def test_chunk_text_long(self):
        """Long text is split into several chunks of roughly chunk_size."""
        from legal_templates_ingestion import LegalTemplatesIngestion
        with patch('legal_templates_ingestion.QdrantClient'):
            ingestion = LegalTemplatesIngestion()
            long_text = "This is a sentence. " * 100
            pieces = ingestion._chunk_text(long_text, chunk_size=200, overlap=50)
            assert len(pieces) > 1
            # Allow some buffer over the nominal chunk size.
            assert all(len(p) <= 250 for p in pieces)

    def test_split_sentences(self):
        """German text splits on sentence boundaries."""
        from legal_templates_ingestion import LegalTemplatesIngestion
        with patch('legal_templates_ingestion.QdrantClient'):
            ingestion = LegalTemplatesIngestion()
            parts = ingestion._split_sentences("Dies ist Satz eins. Dies ist Satz zwei. Und Satz drei.")
            assert len(parts) == 3

    def test_split_sentences_preserves_abbreviations(self):
        """Abbreviations like 'z.B.' must not create sentence boundaries."""
        from legal_templates_ingestion import LegalTemplatesIngestion
        with patch('legal_templates_ingestion.QdrantClient'):
            ingestion = LegalTemplatesIngestion()
            parts = ingestion._split_sentences("Das ist z.B. ein Beispiel. Und noch ein Satz.")
            assert len(parts) == 2
            assert "z.B." in parts[0] or "z.b." in parts[0].lower()

    def test_infer_template_type_privacy(self):
        """A German privacy text is classified as privacy_policy."""
        from legal_templates_ingestion import LegalTemplatesIngestion
        from github_crawler import ExtractedDocument
        from template_sources import SourceConfig, LicenseType
        with patch('legal_templates_ingestion.QdrantClient'):
            ingestion = LegalTemplatesIngestion()
            doc = ExtractedDocument(
                text="Diese Datenschutzerklaerung informiert Sie ueber die Verarbeitung personenbezogener Daten.",
                title="Datenschutz",
                file_path="privacy.md",
                file_type="markdown",
                source_url="https://example.com",
            )
            source = SourceConfig(
                name="test",
                license_type=LicenseType.CC0,
                template_types=["privacy_policy"],
                languages=["de"],
                jurisdiction="DE",
                description="Test",
            )
            assert ingestion._infer_template_type(doc, source) == "privacy_policy"

    def test_infer_clause_category(self):
        """Liability and privacy clauses map to their respective categories."""
        from legal_templates_ingestion import LegalTemplatesIngestion
        with patch('legal_templates_ingestion.QdrantClient'):
            ingestion = LegalTemplatesIngestion()
            liability = "Die Haftung des Anbieters ist auf grobe Fahrlässigkeit beschränkt."
            assert ingestion._infer_clause_category(liability) == "haftung"
            privacy = "Wir verarbeiten personenbezogene Daten gemäß der DSGVO."
            assert ingestion._infer_clause_category(privacy) == "datenschutz"
# =============================================================================
# Admin API Templates Tests
# =============================================================================
class TestTemplatesAdminAPI:
    """Tests for the /api/v1/admin/templates/* status bookkeeping."""

    def test_templates_status_structure(self):
        """A reset status dict has the expected idle shape."""
        from admin_api import _templates_ingestion_status

        _templates_ingestion_status.update(
            running=False,
            last_run=None,
            current_source=None,
            results={},
        )
        assert _templates_ingestion_status["running"] is False
        assert _templates_ingestion_status["results"] == {}

    def test_templates_status_running(self):
        """A running ingestion records its current source and timestamp."""
        from admin_api import _templates_ingestion_status

        _templates_ingestion_status["running"] = True
        _templates_ingestion_status["current_source"] = "github-site-policy"
        _templates_ingestion_status["last_run"] = datetime.now().isoformat()
        assert _templates_ingestion_status["running"] is True
        assert _templates_ingestion_status["current_source"] == "github-site-policy"

    def test_templates_results_tracking(self):
        """Per-source results track counts, status, and errors."""
        from admin_api import _templates_ingestion_status

        _templates_ingestion_status["results"] = {
            "github-site-policy": {
                "status": "completed",
                "documents_found": 15,
                "chunks_indexed": 42,
                "errors": [],
            },
            "opr-vc": {
                "status": "failed",
                "documents_found": 0,
                "chunks_indexed": 0,
                "errors": ["Connection timeout"],
            },
        }
        results = _templates_ingestion_status["results"]
        assert results["github-site-policy"]["status"] == "completed"
        assert results["github-site-policy"]["chunks_indexed"] == 42
        assert results["opr-vc"]["status"] == "failed"
        assert len(results["opr-vc"]["errors"]) > 0
class TestTemplateTypeLabels:
    """Tests for template type labels and constants."""

    def test_template_types_defined(self):
        """Every supported template type key must be present."""
        from template_sources import TEMPLATE_TYPES

        expected_types = (
            "privacy_policy",
            "terms_of_service",
            "cookie_banner",
            "impressum",
            "widerruf",
            "dpa",
        )
        for type_key in expected_types:
            assert type_key in TEMPLATE_TYPES

    def test_jurisdictions_defined(self):
        """Every supported jurisdiction code must be present."""
        from template_sources import JURISDICTIONS

        for code in ("DE", "AT", "CH", "EU", "US"):
            assert code in JURISDICTIONS
# =============================================================================
# Qdrant Service Templates Tests
# =============================================================================
class TestQdrantServiceTemplates:
    """Tests for legal templates Qdrant service functions."""

    @pytest.fixture
    def mock_qdrant_client(self):
        """Yield a MagicMock Qdrant client wired into the service factory."""
        with patch('qdrant_service.get_qdrant_client') as factory:
            fake_client = MagicMock()
            fake_client.get_collections.return_value.collections = []
            factory.return_value = fake_client
            yield fake_client

    def test_legal_templates_collection_name(self):
        """The templates collection has the expected name."""
        from qdrant_service import LEGAL_TEMPLATES_COLLECTION

        assert LEGAL_TEMPLATES_COLLECTION == "bp_legal_templates"

    def test_legal_templates_vector_size(self):
        """The vector size matches BGE-M3's 1024-dimensional embeddings."""
        from qdrant_service import LEGAL_TEMPLATES_VECTOR_SIZE

        assert LEGAL_TEMPLATES_VECTOR_SIZE == 1024
# =============================================================================
# Integration Tests (require mocking external services)
# =============================================================================
class TestTemplatesIntegration:
    """Integration tests for the templates system."""

    @pytest.fixture
    def mock_all_services(self):
        """Patch the Qdrant and HTTP clients used by the ingestion pipeline."""
        with patch('legal_templates_ingestion.QdrantClient') as qdrant_mock, \
             patch('legal_templates_ingestion.httpx.AsyncClient') as http_mock:
            fake_qdrant = MagicMock()
            fake_qdrant.get_collections.return_value.collections = []
            qdrant_mock.return_value = fake_qdrant

            fake_http = AsyncMock()
            fake_http.post.return_value.json.return_value = {"embeddings": [[0.1] * 1024]}
            fake_http.post.return_value.raise_for_status = MagicMock()
            http_mock.return_value.__aenter__.return_value = fake_http

            yield {"qdrant": fake_qdrant, "http": fake_http}

    def test_full_chunk_creation_pipeline(self, mock_all_services):
        """A German markdown privacy policy becomes at least one tagged chunk."""
        from legal_templates_ingestion import LegalTemplatesIngestion
        from github_crawler import ExtractedDocument
        from template_sources import SourceConfig, LicenseType

        ingestion = LegalTemplatesIngestion()
        document = ExtractedDocument(
            text="# Datenschutzerklaerung\n\nWir nehmen den Schutz Ihrer personenbezogenen Daten sehr ernst. Diese Datenschutzerklaerung informiert Sie ueber die Verarbeitung Ihrer Daten gemaess der DSGVO.",
            title="Datenschutzerklaerung",
            file_path="privacy.md",
            file_type="markdown",
            source_url="https://example.com/privacy.md",
            source_commit="abc123",
            placeholders=["[FIRMENNAME]"],
            language="de",  # Explicitly set language
        )
        source_config = SourceConfig(
            name="test-source",
            license_type=LicenseType.CC0,
            template_types=["privacy_policy"],
            languages=["de"],
            jurisdiction="DE",
            description="Test source",
            repo_url="https://github.com/test/repo",
        )

        chunks = ingestion._create_chunks(document, source_config)

        assert len(chunks) >= 1
        first_chunk = chunks[0]
        assert first_chunk.template_type == "privacy_policy"
        assert first_chunk.language == "de"
        assert first_chunk.jurisdiction == "DE"
        assert first_chunk.license_id == "cc0"
        assert first_chunk.attribution_required is False
        assert "[FIRMENNAME]" in first_chunk.placeholders
# =============================================================================
# Test Runner Configuration
# =============================================================================
if __name__ == "__main__":
    # Allow running this test module directly without invoking pytest.
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,349 @@
"""
Unit Tests for Mail Module
Tests for:
- TaskService: Priority calculation, deadline handling
- AIEmailService: Sender classification, deadline extraction
- Models: Validation, known authorities
"""
import pytest
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
# Import the modules to test
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from mail.models import (
TaskPriority,
TaskStatus,
SenderType,
EmailCategory,
KNOWN_AUTHORITIES_NI,
DeadlineExtraction,
EmailAccountCreate,
TaskCreate,
classify_sender_by_domain,
)
from mail.task_service import TaskService
from mail.ai_service import AIEmailService
class TestKnownAuthoritiesNI:
    """Tests for Niedersachsen authority domain matching."""

    def test_kultusministerium_domain(self):
        """The MK Niedersachsen domain maps to KULTUSMINISTERIUM."""
        entry = KNOWN_AUTHORITIES_NI.get("@mk.niedersachsen.de")
        assert entry is not None
        assert entry["type"] == SenderType.KULTUSMINISTERIUM

    def test_rlsb_domain(self):
        """The RLSB domain maps to RLSB."""
        entry = KNOWN_AUTHORITIES_NI.get("@rlsb.de")
        assert entry is not None
        assert entry["type"] == SenderType.RLSB

    def test_landesschulbehoerde_domain(self):
        """The Landesschulbehörde domain maps to LANDESSCHULBEHOERDE."""
        entry = KNOWN_AUTHORITIES_NI.get("@landesschulbehoerde-nds.de")
        assert entry is not None
        assert entry["type"] == SenderType.LANDESSCHULBEHOERDE

    def test_nibis_domain(self):
        """The NiBiS domain maps to NIBIS."""
        entry = KNOWN_AUTHORITIES_NI.get("@nibis.de")
        assert entry is not None
        assert entry["type"] == SenderType.NIBIS

    def test_unknown_domain_not_in_list(self):
        """Generic mail providers are not treated as known authorities."""
        for domain in ("@gmail.com", "@example.de"):
            assert domain not in KNOWN_AUTHORITIES_NI
class TestTaskServicePriority:
    """Tests for TaskService priority calculation."""

    @pytest.fixture
    def task_service(self):
        return TaskService()

    def test_priority_from_kultusministerium(self, task_service):
        """Kultusministerium should result in HIGH priority."""
        assert task_service._get_priority_from_sender(SenderType.KULTUSMINISTERIUM) == TaskPriority.HIGH

    def test_priority_from_rlsb(self, task_service):
        """RLSB should result in HIGH priority."""
        assert task_service._get_priority_from_sender(SenderType.RLSB) == TaskPriority.HIGH

    def test_priority_from_nibis(self, task_service):
        """NiBiS should result in MEDIUM priority."""
        assert task_service._get_priority_from_sender(SenderType.NIBIS) == TaskPriority.MEDIUM

    def test_priority_from_privatperson(self, task_service):
        """Privatperson should result in LOW priority."""
        assert task_service._get_priority_from_sender(SenderType.PRIVATPERSON) == TaskPriority.LOW
class TestTaskServiceDeadlineAdjustment:
    """Tests for TaskService deadline-based priority adjustment."""

    @pytest.fixture
    def task_service(self):
        return TaskService()

    def test_urgent_for_tomorrow(self, task_service):
        """A deadline one day out escalates to URGENT."""
        due = datetime.now() + timedelta(days=1)
        assert task_service._adjust_priority_for_deadline(TaskPriority.LOW, due) == TaskPriority.URGENT

    def test_urgent_for_today(self, task_service):
        """A same-day deadline escalates to URGENT."""
        due = datetime.now() + timedelta(hours=5)
        assert task_service._adjust_priority_for_deadline(TaskPriority.LOW, due) == TaskPriority.URGENT

    def test_high_for_3_days(self, task_service):
        """A deadline in 3 days keeps an already-HIGH priority at HIGH."""
        due = datetime.now() + timedelta(days=3)
        # Note: max() compares enum by value string, so we test with HIGH input
        assert task_service._adjust_priority_for_deadline(TaskPriority.HIGH, due) == TaskPriority.HIGH

    def test_medium_for_7_days(self, task_service):
        """A deadline in 7 days raises LOW to at least MEDIUM."""
        due = datetime.now() + timedelta(days=7)
        assert task_service._adjust_priority_for_deadline(TaskPriority.LOW, due) == TaskPriority.MEDIUM

    def test_no_change_for_far_deadline(self, task_service):
        """A deadline far in the future leaves the priority untouched."""
        due = datetime.now() + timedelta(days=30)
        assert task_service._adjust_priority_for_deadline(TaskPriority.LOW, due) == TaskPriority.LOW
class TestTaskServiceDescriptionBuilder:
    """Tests for TaskService description building."""

    @pytest.fixture
    def task_service(self):
        return TaskService()

    def test_description_with_deadlines(self, task_service):
        """The built description lists deadlines plus sender information."""
        extracted = [
            DeadlineExtraction(
                deadline_date=datetime(2026, 1, 15),
                description="Einreichung der Unterlagen",
                is_firm=True,
                confidence=0.9,
                source_text="bis zum 15.01.2026",
            )
        ]
        email = {
            "sender_email": "test@mk.niedersachsen.de",
            "body_preview": "Bitte reichen Sie die Unterlagen ein.",
        }

        built = task_service._build_task_description(extracted, email)

        for fragment in (
            "**Fristen:**",
            "15.01.2026",
            "Einreichung der Unterlagen",
            "(verbindlich)",
            "test@mk.niedersachsen.de",
        ):
            assert fragment in built

    def test_description_without_deadlines(self, task_service):
        """Without deadlines there is no Fristen section, only sender data."""
        email = {
            "sender_email": "sender@example.de",
            "body_preview": "Test preview text",
        }

        built = task_service._build_task_description([], email)

        assert "**Fristen:**" not in built
        assert "sender@example.de" in built
class TestSenderClassification:
    """Tests for sender classification via classify_sender_by_domain."""

    def test_classify_kultusministerium(self):
        """An MK address is recognized as KULTUSMINISTERIUM."""
        classification = classify_sender_by_domain("referat@mk.niedersachsen.de")
        assert classification is not None
        assert classification.sender_type == SenderType.KULTUSMINISTERIUM

    def test_classify_rlsb(self):
        """An RLSB address is recognized as RLSB."""
        classification = classify_sender_by_domain("info@rlsb.de")
        assert classification is not None
        assert classification.sender_type == SenderType.RLSB

    def test_classify_unknown_domain(self):
        """Unknown domains yield no classification at all."""
        assert classify_sender_by_domain("user@gmail.com") is None
class TestAIEmailServiceDeadlineExtraction:
    """Tests for AIEmailService deadline extraction from text."""

    @pytest.fixture
    def ai_service(self):
        return AIEmailService()

    def _found_dates(self, extractions):
        """Collect ISO date strings from a list of DeadlineExtraction results."""
        return [item.deadline_date.strftime("%Y-%m-%d") for item in extractions]

    def test_extract_deadline_bis_format(self, ai_service):
        """'bis zum DD.MM.YYYY' phrasing is detected."""
        text = "Bitte senden Sie die Unterlagen bis zum 15.01.2027 ein."
        extractions = ai_service._extract_deadlines_regex(text)
        assert len(extractions) >= 1
        assert "2027-01-15" in self._found_dates(extractions)

    def test_extract_deadline_frist_format(self, ai_service):
        """'Frist: DD.MM.YYYY' phrasing is detected."""
        text = "Die Frist: 20.02.2027 muss eingehalten werden."
        extractions = ai_service._extract_deadlines_regex(text)
        assert len(extractions) >= 1
        assert "2027-02-20" in self._found_dates(extractions)

    def test_no_deadline_in_text(self, ai_service):
        """Text without any date yields no extractions."""
        text = "Dies ist eine allgemeine Mitteilung ohne Datum."
        assert ai_service._extract_deadlines_regex(text) == []
class TestAIEmailServiceCategoryRules:
    """Tests for AIEmailService category classification rules."""

    @pytest.fixture
    def ai_service(self):
        return AIEmailService()

    def test_fortbildung_category(self, ai_service):
        """Keywords like Fortbildung/Seminar trigger the FORTBILDUNG category."""
        # Use keywords that clearly match FORTBILDUNG: fortbildung, seminar, workshop
        subject = "Fortbildung NLQ Seminar"
        body = "Wir bieten eine Weiterbildung zum Thema Didaktik an."
        category, _confidence = ai_service._classify_category_rules(subject, body, SenderType.UNBEKANNT)
        assert category == EmailCategory.FORTBILDUNG

    def test_personal_category(self, ai_service):
        """Keywords like Personalrat/Krankmeldung trigger the PERSONAL category."""
        # Use keywords that clearly match PERSONAL: personalrat, versetzung, krankmeldung
        subject = "Personalrat Sitzung"
        body = "Thema: Krankmeldung und Beurteilung"
        category, _confidence = ai_service._classify_category_rules(subject, body, SenderType.UNBEKANNT)
        assert category == EmailCategory.PERSONAL

    def test_finanzen_category(self, ai_service):
        """Keywords like Budget/Haushalt trigger the FINANZEN category."""
        # Use keywords that clearly match FINANZEN: budget, haushalt, abrechnung
        subject = "Haushalt 2026 Budget"
        body = "Die Abrechnung und Erstattung für das neue Etat."
        category, _confidence = ai_service._classify_category_rules(subject, body, SenderType.UNBEKANNT)
        assert category == EmailCategory.FINANZEN
class TestEmailAccountCreateValidation:
    """Tests for EmailAccountCreate Pydantic model validation."""

    def test_valid_account_creation(self):
        """A complete payload produces a valid account object."""
        account = EmailAccountCreate(
            email="schulleitung@grundschule-xy.de",
            display_name="Schulleitung",
            imap_host="imap.example.com",
            imap_port=993,
            smtp_host="smtp.example.com",
            smtp_port=587,
            password="secret123",
        )
        assert account.email == "schulleitung@grundschule-xy.de"
        assert account.imap_port == 993
        assert account.imap_ssl is True  # Default

    def test_default_ssl_true(self):
        """Both IMAP and SMTP SSL flags default to True."""
        account = EmailAccountCreate(
            email="test@example.com",
            display_name="Test Account",
            imap_host="imap.example.com",
            imap_port=993,
            smtp_host="smtp.example.com",
            smtp_port=587,
            password="secret",
        )
        assert account.imap_ssl is True
        assert account.smtp_ssl is True
class TestTaskCreateValidation:
    """Tests for TaskCreate Pydantic model validation."""

    def test_valid_task_creation(self):
        """A fully specified task keeps its title and priority."""
        task = TaskCreate(
            title="Unterlagen einreichen",
            description="Bitte alle Dokumente bis Freitag.",
            priority=TaskPriority.HIGH,
            deadline=datetime(2026, 1, 15),
        )
        assert task.title == "Unterlagen einreichen"
        assert task.priority == TaskPriority.HIGH

    def test_default_priority_medium(self):
        """Omitting priority falls back to MEDIUM."""
        task = TaskCreate(title="Einfache Aufgabe")
        assert task.priority == TaskPriority.MEDIUM

    def test_optional_deadline(self):
        """Omitting the deadline leaves it as None."""
        task = TaskCreate(title="Keine Frist")
        assert task.deadline is None
# Integration test placeholder
class TestMailModuleIntegration:
    """Integration tests (require database connection)."""

    @pytest.mark.skip(reason="Requires database connection")
    @pytest.mark.asyncio
    async def test_create_task_from_email(self):
        """Placeholder: create a task from an email analysis (needs DB)."""
        pass

    @pytest.mark.skip(reason="Requires database connection")
    @pytest.mark.asyncio
    async def test_dashboard_stats(self):
        """Placeholder: dashboard statistics calculation (needs DB)."""
        pass
if __name__ == "__main__":
    # Allow running this test module directly without invoking pytest.
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,799 @@
"""
Tests for OCR Labeling API
Tests session management, image upload, labeling workflow, and training export.
BACKLOG: Feature not yet fully integrated - requires external OCR services
See: https://macmini:3002/infrastructure/tests
"""
import pytest
# Mark all tests in this module as expected failures (backlog item)
# Module-wide xfail: every test below is tolerated to fail until the OCR
# labeling backends (MinIO, Vision OCR, PaddleOCR, TrOCR, Donut) are wired
# into CI. strict=False means an unexpectedly passing test is NOT an error.
pytestmark = pytest.mark.xfail(
reason="ocr_labeling requires external services not available in CI - Backlog item",
strict=False # Don't fail if test unexpectedly passes
)
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime
import io
import json
import hashlib
# =============================================================================
# Test Fixtures
# =============================================================================
@pytest.fixture
def mock_db_pool():
    """Yield a (pool, conn) pair of AsyncMocks wired into metrics_db.get_pool."""
    with patch('metrics_db.get_pool') as get_pool_patch:
        fake_conn = AsyncMock()
        fake_pool = AsyncMock()
        # `async with pool.acquire() as conn:` must hand back the fake connection.
        fake_pool.acquire.return_value.__aenter__.return_value = fake_conn
        get_pool_patch.return_value = fake_pool
        yield fake_pool, fake_conn
@pytest.fixture
def mock_minio():
    """Yield (upload, get) mocks standing in for the MinIO storage helpers."""
    with patch('ocr_labeling_api.MINIO_AVAILABLE', True), \
         patch('ocr_labeling_api.upload_ocr_image') as upload_mock, \
         patch('ocr_labeling_api.get_ocr_image') as get_mock:
        upload_mock.return_value = "ocr-labeling/session-123/item-456.png"
        get_mock.return_value = b"\x89PNG fake image data"
        yield upload_mock, get_mock
@pytest.fixture
def mock_vision_ocr():
    """Yield a fake Vision OCR service returning a fixed text/confidence."""
    with patch('ocr_labeling_api.VISION_OCR_AVAILABLE', True), \
         patch('ocr_labeling_api.get_vision_ocr_service') as factory:
        fake_service = AsyncMock()
        fake_service.is_available.return_value = True

        fake_result = MagicMock()
        fake_result.text = "Erkannter Text aus dem Bild"
        fake_result.confidence = 0.87
        fake_service.extract_text.return_value = fake_result

        factory.return_value = fake_service
        yield fake_service
@pytest.fixture
def mock_training_export():
    """Yield a fake training export service with one canned export batch."""
    with patch('ocr_labeling_api.TRAINING_EXPORT_AVAILABLE', True), \
         patch('ocr_labeling_api.get_training_export_service') as factory:
        fake_service = MagicMock()

        fake_export = MagicMock()
        fake_export.export_path = "/app/ocr-exports/generic/20260121_120000"
        fake_export.manifest_path = "/app/ocr-exports/generic/20260121_120000/manifest.json"
        fake_export.batch_id = "20260121_120000"
        fake_service.export.return_value = fake_export

        fake_service.list_exports.return_value = [
            {"format": "generic", "batch_id": "20260121_120000", "sample_count": 10}
        ]

        factory.return_value = fake_service
        yield fake_service
@pytest.fixture
def sample_session():
    """Return a dict shaped like a labeling session DB row."""
    return dict(
        id="session-123",
        name="Test Session",
        source_type="klausur",
        description="Test description",
        ocr_model="llama3.2-vision:11b",
        total_items=5,
        labeled_items=2,
        confirmed_items=1,
        corrected_items=1,
        skipped_items=0,
        created_at=datetime.utcnow(),
    )
@pytest.fixture
def sample_item():
    """Return a dict shaped like a labeling item DB row (still unlabeled)."""
    return dict(
        id="item-456",
        session_id="session-123",
        session_name="Test Session",
        image_path="/app/ocr-labeling/session-123/item-456.png",
        ocr_text="Erkannter Text",
        ocr_confidence=0.87,
        ground_truth=None,
        status="pending",
        metadata={"page": 1},
        created_at=datetime.utcnow(),
    )
# =============================================================================
# Session Management Tests
# =============================================================================
class TestSessionCreation:
    """Tests for session creation."""

    @pytest.mark.asyncio
    async def test_create_session_success(self, mock_db_pool):
        """Creating a session goes through the mocked DB pool.

        FIX: removed the unused `from ocr_labeling_api import SessionCreate`
        import — this test only exercises the metrics_db layer.
        """
        from metrics_db import create_ocr_labeling_session

        pool, conn = mock_db_pool
        conn.execute.return_value = None

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
            await create_ocr_labeling_session(
                session_id="session-123",
                name="Test Session",
                source_type="klausur",
                description="Test",
                ocr_model="llama3.2-vision:11b",
            )

        # Should call execute to insert
        assert pool.acquire.called

    def test_session_create_model_validation(self):
        """SessionCreate accepts a minimal payload and applies defaults."""
        from ocr_labeling_api import SessionCreate

        # Valid session
        session = SessionCreate(
            name="Test Session",
            source_type="klausur",
            description="Test description",
        )
        assert session.name == "Test Session"
        assert session.source_type == "klausur"
        assert session.ocr_model == "llama3.2-vision:11b"  # default

    def test_session_create_with_custom_model(self):
        """SessionCreate keeps an explicitly chosen OCR model."""
        from ocr_labeling_api import SessionCreate

        session = SessionCreate(
            name="TrOCR Session",
            source_type="handwriting_sample",
            ocr_model="trocr-base",
        )
        assert session.ocr_model == "trocr-base"
class TestSessionListing:
    """Tests for session listing."""

    @pytest.mark.asyncio
    async def test_get_sessions_empty(self):
        """Without a DB pool, listing sessions yields an empty list."""
        from metrics_db import get_ocr_labeling_sessions

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None):
            assert await get_ocr_labeling_sessions() == []

    @pytest.mark.asyncio
    async def test_get_session_not_found(self):
        """Without a DB pool, a single-session lookup yields None."""
        from metrics_db import get_ocr_labeling_session

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None):
            assert await get_ocr_labeling_session("non-existent-id") is None
# =============================================================================
# Image Upload Tests
# =============================================================================
class TestImageUpload:
    """Tests for image upload functionality."""

    def test_compute_image_hash(self):
        """Hashing identical bytes twice yields the same SHA256 hex digest."""
        from ocr_labeling_api import compute_image_hash

        image_data = b"\x89PNG fake image data"
        hash1 = compute_image_hash(image_data)
        hash2 = compute_image_hash(image_data)
        # Same data should produce same hash
        assert hash1 == hash2
        assert len(hash1) == 64  # SHA256 hex length

    def test_compute_image_hash_different_data(self):
        """Different images produce different hashes."""
        from ocr_labeling_api import compute_image_hash

        hash1 = compute_image_hash(b"image 1 data")
        hash2 = compute_image_hash(b"image 2 data")
        assert hash1 != hash2

    def test_save_image_locally(self, tmp_path):
        """Images are written below the (patched) local storage root.

        FIX: the original imported ``save_image_locally`` twice (once before
        and once inside the patch block); a re-import returns the same module
        object, so the inner import was redundant and has been removed. The
        unused ``LOCAL_STORAGE_PATH`` import was dropped as well — the path is
        overridden via the patch target string.
        """
        import os
        from ocr_labeling_api import save_image_locally

        # Redirect the storage root into pytest's tmp_path for isolation.
        with patch('ocr_labeling_api.LOCAL_STORAGE_PATH', str(tmp_path)):
            image_data = b"\x89PNG fake image data"
            filepath = save_image_locally(
                session_id="session-123",
                item_id="item-456",
                image_data=image_data,
                extension="png",
            )
            assert filepath.endswith("item-456.png")
            # File should exist on disk
            assert os.path.exists(filepath)

    def test_get_image_url_local(self):
        """Local storage paths are rewritten to API-served URLs."""
        from ocr_labeling_api import get_image_url, LOCAL_STORAGE_PATH

        local_path = f"{LOCAL_STORAGE_PATH}/session-123/item-456.png"
        url = get_image_url(local_path)
        assert url == "/api/v1/ocr-label/images/session-123/item-456.png"

    def test_get_image_url_minio(self):
        """MinIO object paths are passed through unchanged."""
        from ocr_labeling_api import get_image_url

        minio_path = "ocr-labeling/session-123/item-456.png"
        url = get_image_url(minio_path)
        # Non-local paths are passed through
        assert url == minio_path
# =============================================================================
# Labeling Workflow Tests
# =============================================================================
class TestConfirmLabel:
    """Tests for label confirmation."""

    @pytest.mark.asyncio
    async def test_confirm_label_success(self, mock_db_pool):
        """Confirming a label updates the item via the DB connection."""
        from metrics_db import confirm_ocr_label

        pool, conn = mock_db_pool
        conn.fetchrow.return_value = {"ocr_text": "Test text"}
        conn.execute.return_value = None

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
            await confirm_ocr_label(
                item_id="item-456",
                labeled_by="admin",
                label_time_seconds=5,
            )

        # Should update item status and ground_truth
        assert conn.execute.called

    def test_confirm_request_validation(self):
        """ConfirmRequest keeps its item id and timing."""
        from ocr_labeling_api import ConfirmRequest

        request = ConfirmRequest(item_id="item-456", label_time_seconds=5)
        assert request.item_id == "item-456"
        assert request.label_time_seconds == 5
class TestCorrectLabel:
    """Tests for label correction."""

    @pytest.mark.asyncio
    async def test_correct_label_success(self, mock_db_pool):
        """Correcting a label writes the new ground truth via the DB."""
        from metrics_db import correct_ocr_label

        pool, conn = mock_db_pool
        conn.execute.return_value = None

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
            await correct_ocr_label(
                item_id="item-456",
                ground_truth="Korrigierter Text",
                labeled_by="admin",
                label_time_seconds=15,
            )

        # Should update item with corrected ground_truth
        assert conn.execute.called

    def test_correct_request_validation(self):
        """CorrectRequest keeps the corrected text and item id."""
        from ocr_labeling_api import CorrectRequest

        request = CorrectRequest(
            item_id="item-456",
            ground_truth="Korrigierter Text",
            label_time_seconds=15,
        )
        assert request.item_id == "item-456"
        assert request.ground_truth == "Korrigierter Text"
class TestSkipItem:
    """Tests for item skipping."""

    @pytest.mark.asyncio
    async def test_skip_item_success(self, mock_db_pool):
        """Skipping an item flips its status through the DB connection."""
        from metrics_db import skip_ocr_item

        pool, conn = mock_db_pool
        conn.execute.return_value = None

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
            await skip_ocr_item(
                item_id="item-456",
                labeled_by="admin",
            )

        # Should update item status to skipped
        assert conn.execute.called
# =============================================================================
# Statistics Tests
# =============================================================================
class TestLabelingStats:
    """Tests for labeling statistics."""

    @pytest.mark.asyncio
    async def test_get_stats_no_db(self):
        """Without a DB pool the stats call degrades gracefully."""
        from metrics_db import get_ocr_labeling_stats

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None):
            stats = await get_ocr_labeling_stats()
            assert "error" in stats or stats.get("total_items", 0) == 0

    def test_stats_response_model(self):
        """StatsResponse stores the counters it is given."""
        from ocr_labeling_api import StatsResponse

        stats = StatsResponse(
            total_items=100,
            labeled_items=50,
            confirmed_items=40,
            corrected_items=10,
            pending_items=50,
            accuracy_rate=0.8,
        )
        assert stats.total_items == 100
        assert stats.accuracy_rate == 0.8
# =============================================================================
# Export Tests
# =============================================================================
class TestTrainingExport:
    """Tests for training data export."""

    def test_export_request_validation(self):
        """ExportRequest defaults to 'generic' and keeps explicit formats."""
        from ocr_labeling_api import ExportRequest

        # Default format is generic
        assert ExportRequest().export_format == "generic"
        # Explicit formats are kept verbatim
        assert ExportRequest(export_format="trocr").export_format == "trocr"
        assert ExportRequest(export_format="llama_vision").export_format == "llama_vision"

    @pytest.mark.asyncio
    async def test_export_training_samples(self, mock_db_pool):
        """Export reads confirmed samples through the DB connection."""
        from metrics_db import export_training_samples

        pool, conn = mock_db_pool
        conn.fetch.return_value = [
            {
                "id": "sample-1",
                "image_path": "/app/ocr-labeling/session-123/item-1.png",
                "ground_truth": "Text 1",
            },
            {
                "id": "sample-2",
                "image_path": "/app/ocr-labeling/session-123/item-2.png",
                "ground_truth": "Text 2",
            },
        ]

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
            await export_training_samples(
                export_format="generic",
                exported_by="admin",
            )

        # Should return exportable samples
        assert conn.fetch.called or conn.execute.called
class TestTrainingExportService:
    """Tests for training export service."""

    def test_trocr_export_format(self):
        """The TrOCR sample schema pairs an image file name with its text."""
        sample = {
            "file_name": "images/sample-1.png",
            "text": "Ground truth text",
            "id": "sample-1",
        }
        for required_key in ("file_name", "text"):
            assert required_key in sample

    def test_llama_vision_export_format(self):
        """The Llama Vision schema is a 3-turn chat ending with the answer."""
        sample = {
            "id": "sample-1",
            "messages": [
                {"role": "system", "content": "Du bist ein OCR-Experte..."},
                {"role": "user", "content": [
                    {"type": "image_url", "image_url": {"url": "..."}},
                    {"type": "text", "text": "Lies den Text..."},
                ]},
                {"role": "assistant", "content": "Ground truth text"},
            ],
        }
        assert "messages" in sample
        conversation = sample["messages"]
        assert len(conversation) == 3
        assert conversation[2]["role"] == "assistant"

    def test_generic_export_format(self):
        """The generic schema carries both OCR output and ground truth."""
        sample = {
            "id": "sample-1",
            "image_path": "images/sample-1.png",
            "ground_truth": "Ground truth text",
            "ocr_text": "OCR recognized text",
            "ocr_confidence": 0.87,
            "metadata": {},
        }
        for required_key in ("image_path", "ground_truth"):
            assert required_key in sample
# =============================================================================
# OCR Processing Tests
# =============================================================================
class TestOCRProcessing:
    """Tests for OCR processing."""

    @pytest.mark.asyncio
    async def test_run_ocr_on_image_no_service(self):
        """With every OCR backend disabled the call returns (None, 0.0)."""
        from ocr_labeling_api import run_ocr_on_image

        all_backends_off = patch.multiple(
            'ocr_labeling_api',
            VISION_OCR_AVAILABLE=False,
            PADDLEOCR_AVAILABLE=False,
            TROCR_AVAILABLE=False,
            DONUT_AVAILABLE=False,
        )
        with all_backends_off:
            recognized, score = await run_ocr_on_image(
                image_data=b"fake image",
                filename="test.png",
            )
        assert recognized is None
        assert score == 0.0

    @pytest.mark.asyncio
    async def test_run_ocr_on_image_success(self, mock_vision_ocr):
        """With Vision OCR mocked, its text and confidence come back."""
        from ocr_labeling_api import run_ocr_on_image

        recognized, score = await run_ocr_on_image(
            image_data=b"fake image",
            filename="test.png",
        )
        assert recognized == "Erkannter Text aus dem Bild"
        assert score == 0.87
# =============================================================================
# OCR Model Dispatcher Tests
# =============================================================================
class TestOCRModelDispatcher:
    """Tests for the OCR model dispatcher (v1.1.0)."""

    @pytest.mark.asyncio
    async def test_dispatcher_vision_model_default(self, mock_vision_ocr):
        """The llama vision model name is routed to Vision OCR."""
        from ocr_labeling_api import run_ocr_on_image

        recognized, score = await run_ocr_on_image(
            image_data=b"fake image",
            filename="test.png",
            model="llama3.2-vision:11b",
        )
        assert recognized == "Erkannter Text aus dem Bild"
        assert score == 0.87

    @pytest.mark.asyncio
    async def test_dispatcher_paddleocr_model(self):
        """The 'paddleocr' model name is routed to PaddleOCR."""
        from ocr_labeling_api import run_ocr_on_image

        # PaddleOCR returns (regions, combined_text).
        paddle_result = ([], "PaddleOCR erkannter Text")
        with patch('ocr_labeling_api.PADDLEOCR_AVAILABLE', True), \
             patch('ocr_labeling_api.run_paddle_ocr', return_value=paddle_result):
            recognized, score = await run_ocr_on_image(
                image_data=b"fake image",
                filename="test.png",
                model="paddleocr",
            )
        assert recognized == "PaddleOCR erkannter Text"

    @pytest.mark.asyncio
    async def test_dispatcher_paddleocr_fallback_to_vision(self, mock_vision_ocr):
        """When PaddleOCR is unavailable the dispatcher uses Vision OCR."""
        from ocr_labeling_api import run_ocr_on_image

        with patch('ocr_labeling_api.PADDLEOCR_AVAILABLE', False):
            recognized, score = await run_ocr_on_image(
                image_data=b"fake image",
                filename="test.png",
                model="paddleocr",
            )
        # Should fall back to Vision OCR
        assert recognized == "Erkannter Text aus dem Bild"
        assert score == 0.87

    @pytest.mark.asyncio
    async def test_dispatcher_trocr_model(self):
        """The 'trocr' model name is routed to TrOCR."""
        from ocr_labeling_api import run_ocr_on_image

        async def fake_trocr(image_data):
            return "TrOCR erkannter Text", 0.85

        with patch('ocr_labeling_api.TROCR_AVAILABLE', True), \
             patch('ocr_labeling_api.run_trocr_ocr', fake_trocr):
            recognized, score = await run_ocr_on_image(
                image_data=b"fake image",
                filename="test.png",
                model="trocr",
            )
        assert recognized == "TrOCR erkannter Text"
        assert score == 0.85

    @pytest.mark.asyncio
    async def test_dispatcher_donut_model(self):
        """The 'donut' model name is routed to Donut."""
        from ocr_labeling_api import run_ocr_on_image

        async def fake_donut(image_data):
            return "Donut erkannter Text", 0.80

        with patch('ocr_labeling_api.DONUT_AVAILABLE', True), \
             patch('ocr_labeling_api.run_donut_ocr', fake_donut):
            recognized, score = await run_ocr_on_image(
                image_data=b"fake image",
                filename="test.png",
                model="donut",
            )
        assert recognized == "Donut erkannter Text"
        assert score == 0.80

    @pytest.mark.asyncio
    async def test_dispatcher_unknown_model_uses_vision(self, mock_vision_ocr):
        """Unknown model names fall through to Vision OCR."""
        from ocr_labeling_api import run_ocr_on_image

        recognized, score = await run_ocr_on_image(
            image_data=b"fake image",
            filename="test.png",
            model="unknown-model",
        )
        # Unknown model should fall back to Vision OCR
        assert recognized == "Erkannter Text aus dem Bild"
        assert score == 0.87
class TestOCRModelTypes:
    """Tests for OCR model type definitions.

    Each test only verifies that SessionCreate accepts the given
    ``ocr_model`` value; no OCR is performed.
    """

    def test_session_with_paddleocr_model(self):
        """Test session creation with PaddleOCR model."""
        from ocr_labeling_api import SessionCreate
        session = SessionCreate(
            name="PaddleOCR Session",
            source_type="klausur",
            ocr_model="paddleocr",
        )
        assert session.ocr_model == "paddleocr"

    def test_session_with_donut_model(self):
        """Test session creation with Donut model."""
        from ocr_labeling_api import SessionCreate
        session = SessionCreate(
            name="Donut Session",
            source_type="scan",
            ocr_model="donut",
        )
        assert session.ocr_model == "donut"

    def test_session_with_trocr_model(self):
        """Test session creation with TrOCR model."""
        from ocr_labeling_api import SessionCreate
        session = SessionCreate(
            name="TrOCR Session",
            source_type="handwriting_sample",
            ocr_model="trocr",
        )
        assert session.ocr_model == "trocr"
# =============================================================================
# API Response Model Tests
# =============================================================================
class TestResponseModels:
    """Tests for API response models.

    Verifies that SessionResponse / ItemResponse can be constructed with
    representative field values.
    """

    def test_session_response_model(self):
        """Test SessionResponse model."""
        # Local import: the module header is outside this view, so bring
        # `timezone` in here rather than touching file-level imports.
        from datetime import timezone

        from ocr_labeling_api import SessionResponse
        session = SessionResponse(
            id="session-123",
            name="Test Session",
            source_type="klausur",
            description="Test",
            ocr_model="llama3.2-vision:11b",
            total_items=10,
            labeled_items=5,
            confirmed_items=3,
            corrected_items=2,
            skipped_items=0,
            # Timezone-aware replacement for the deprecated datetime.utcnow()
            # (deprecated since Python 3.12; utcnow() also returns a *naive*
            # datetime, which is error-prone).
            created_at=datetime.now(timezone.utc),
        )
        assert session.id == "session-123"
        assert session.total_items == 10

    def test_item_response_model(self):
        """Test ItemResponse model."""
        from datetime import timezone

        from ocr_labeling_api import ItemResponse
        item = ItemResponse(
            id="item-456",
            session_id="session-123",
            session_name="Test Session",
            image_path="/app/ocr-labeling/session-123/item-456.png",
            image_url="/api/v1/ocr-label/images/session-123/item-456.png",
            ocr_text="Test OCR text",
            ocr_confidence=0.87,
            ground_truth=None,
            status="pending",
            metadata={"page": 1},
            # Aware timestamp instead of deprecated datetime.utcnow().
            created_at=datetime.now(timezone.utc),
        )
        assert item.id == "item-456"
        assert item.status == "pending"
# =============================================================================
# Deduplication Tests
# =============================================================================
class TestDeduplication:
    """Tests for image deduplication.

    compute_image_hash() is exercised with raw byte strings; real image
    decoding is not required for the hash contract tested here.
    """

    def test_hash_based_deduplication(self):
        """Test that same images produce same hash for deduplication."""
        from ocr_labeling_api import compute_image_hash
        # Same content should be detected as duplicate
        image1 = b"\x89PNG\x0d\x0a\x1a\x0a test image content"
        image2 = b"\x89PNG\x0d\x0a\x1a\x0a test image content"
        hash1 = compute_image_hash(image1)
        hash2 = compute_image_hash(image2)
        assert hash1 == hash2

    def test_unique_images_different_hash(self):
        """Test that different images produce different hashes."""
        from ocr_labeling_api import compute_image_hash
        image1 = b"\x89PNG unique content 1"
        image2 = b"\x89PNG unique content 2"
        hash1 = compute_image_hash(image1)
        hash2 = compute_image_hash(image2)
        assert hash1 != hash2
# =============================================================================
# Integration Tests (require running services)
# =============================================================================
class TestOCRLabelingIntegration:
    """Integration tests - require Ollama, MinIO, PostgreSQL running.

    All tests are skipped placeholders documenting the intended
    end-to-end coverage; they carry no executable logic yet.
    """

    @pytest.mark.skip(reason="Requires running Ollama with llama3.2-vision")
    @pytest.mark.asyncio
    async def test_full_labeling_workflow(self):
        """Test complete labeling workflow."""
        # This would require:
        # 1. Create session
        # 2. Upload image
        # 3. Run OCR
        # 4. Confirm or correct label
        # 5. Export training data
        pass

    @pytest.mark.skip(reason="Requires running PostgreSQL")
    @pytest.mark.asyncio
    async def test_stats_calculation(self):
        """Test statistics calculation with real data."""
        pass

    @pytest.mark.skip(reason="Requires running MinIO")
    @pytest.mark.asyncio
    async def test_image_storage_and_retrieval(self):
        """Test image upload and download from MinIO."""
        pass
# =============================================================================
# Run Tests
# =============================================================================
# Allow running this test module directly (outside the pytest runner).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,356 @@
"""
Tests for RAG Admin API
Tests upload, search, metrics, and storage functionality.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime
import io
import zipfile
# =============================================================================
# Test Fixtures
# =============================================================================
@pytest.fixture
def mock_qdrant_client():
    """Mock Qdrant client.

    Patches admin_api.get_qdrant_client and simulates an empty collection
    list plus a healthy collection (7352 vectors/points, status "green").
    """
    with patch('admin_api.get_qdrant_client') as mock:
        client = MagicMock()
        client.get_collections.return_value.collections = []
        client.get_collection.return_value.vectors_count = 7352
        client.get_collection.return_value.points_count = 7352
        client.get_collection.return_value.status.value = "green"
        mock.return_value = client
        yield client
@pytest.fixture
def mock_minio_client():
    """Mock MinIO client.

    Patches minio_storage._get_minio_client; the fake client reports an
    existing bucket with no objects.
    """
    with patch('minio_storage._get_minio_client') as mock:
        client = MagicMock()
        client.bucket_exists.return_value = True
        client.list_objects.return_value = []
        mock.return_value = client
        yield client
@pytest.fixture
def mock_db_pool():
    """Mock PostgreSQL connection pool (patches metrics_db.get_pool)."""
    with patch('metrics_db.get_pool') as mock:
        # AsyncMock so awaited pool methods resolve without a real DB.
        pool = AsyncMock()
        mock.return_value = pool
        yield pool
# =============================================================================
# Admin API Tests
# =============================================================================
class TestIngestionStatus:
    """Tests for /api/v1/admin/nibis/status endpoint.

    NOTE: these tests mutate the module-level ``_ingestion_status`` dict in
    admin_api directly instead of going through the HTTP endpoint.
    """

    def test_status_not_running(self):
        """Test status when no ingestion is running."""
        from admin_api import _ingestion_status
        # Reset status
        _ingestion_status["running"] = False
        _ingestion_status["last_run"] = None
        _ingestion_status["last_result"] = None
        assert _ingestion_status["running"] is False

    def test_status_running(self):
        """Test status when ingestion is running."""
        from admin_api import _ingestion_status
        _ingestion_status["running"] = True
        _ingestion_status["last_run"] = datetime.now().isoformat()
        assert _ingestion_status["running"] is True
        assert _ingestion_status["last_run"] is not None
class TestUploadAPI:
    """Tests for /api/v1/admin/rag/upload endpoint.

    Operates directly on admin_api's module-level ``_upload_history`` list.
    """

    def test_upload_record_creation(self):
        """Test that upload records are created correctly."""
        from admin_api import _upload_history
        # Clear history
        _upload_history.clear()
        # Simulate upload record
        upload_record = {
            "timestamp": datetime.now().isoformat(),
            "filename": "test.pdf",
            "collection": "bp_nibis_eh",
            "year": 2024,
            "pdfs_extracted": 1,
            "target_directory": "/tmp/test",
        }
        _upload_history.append(upload_record)
        assert len(_upload_history) == 1
        assert _upload_history[0]["filename"] == "test.pdf"

    def test_upload_history_limit(self):
        """Test that upload history is limited to 100 entries."""
        from admin_api import _upload_history
        _upload_history.clear()
        # NOTE(review): the trimming below re-implements the cap locally;
        # it does not verify that the API endpoint itself enforces it.
        # Add 105 entries
        for i in range(105):
            _upload_history.append({
                "timestamp": datetime.now().isoformat(),
                "filename": f"test_{i}.pdf",
            })
            # Trim oldest entry once the 100-entry cap is exceeded.
            if len(_upload_history) > 100:
                _upload_history.pop(0)
        assert len(_upload_history) == 100
class TestSearchFeedback:
    """Tests for feedback storage."""

    def test_feedback_record_format(self):
        """Test feedback record structure."""
        # Assemble the record via keyword args instead of a dict literal.
        record = dict(
            timestamp=datetime.now().isoformat(),
            result_id="test-123",
            rating=4,
            notes="Good result",
        )
        assert "timestamp" in record
        # Rating must fall within the 1-5 star scale.
        assert 1 <= record["rating"] <= 5
# =============================================================================
# MinIO Storage Tests
# =============================================================================
class TestMinIOStorage:
    """Tests for MinIO storage functions."""

    def test_get_minio_path(self):
        """Test MinIO path generation."""
        from minio_storage import get_minio_path
        path = get_minio_path(
            data_type="landes-daten",
            bundesland="ni",
            use_case="klausur",
            year=2024,
            filename="test.pdf",
        )
        assert path == "landes-daten/ni/klausur/2024/test.pdf"

    def test_get_minio_path_teacher_data(self):
        """Test MinIO path for teacher data."""
        from minio_storage import get_minio_path
        # NOTE(review): this test never calls get_minio_path — it asserts on
        # a locally built literal, so it cannot fail if the real path layout
        # changes. The f-prefix is also redundant (no placeholders).
        # Teacher data uses different path structure
        path = f"lehrer-daten/tenant_123/teacher_456/test.pdf.enc"
        assert "lehrer-daten" in path
        assert "tenant_123" in path
        assert ".enc" in path

    @pytest.mark.asyncio
    async def test_storage_stats_no_client(self):
        """Test storage stats when MinIO is not available."""
        from minio_storage import get_storage_stats
        with patch('minio_storage._get_minio_client', return_value=None):
            stats = await get_storage_stats()
            assert stats["connected"] is False
# =============================================================================
# Metrics DB Tests
# =============================================================================
class TestMetricsDB:
    """Tests for PostgreSQL metrics functions.

    All tests patch metrics_db.get_pool to return None, exercising the
    "database unavailable" code paths only.
    """

    @pytest.mark.asyncio
    async def test_store_feedback_no_pool(self):
        """Test feedback storage when DB is not available."""
        from metrics_db import store_feedback
        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None):
            result = await store_feedback(
                result_id="test-123",
                rating=4,
            )
            assert result is False

    @pytest.mark.asyncio
    async def test_calculate_metrics_no_pool(self):
        """Test metrics calculation when DB is not available."""
        from metrics_db import calculate_metrics
        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None):
            metrics = await calculate_metrics()
            assert metrics["connected"] is False

    def test_create_tables_sql_structure(self):
        """Test that SQL table creation is properly structured."""
        # NOTE(review): expected_tables is never used below — the test only
        # checks that init_metrics_tables is callable, not the table names.
        expected_tables = [
            "rag_search_feedback",
            "rag_search_logs",
            "rag_upload_history",
        ]
        # Read the metrics_db module to check table names
        from metrics_db import init_metrics_tables
        # The function should create these tables
        assert callable(init_metrics_tables)
# =============================================================================
# Integration Tests (require running services)
# =============================================================================
class TestRAGIntegration:
    """Integration tests - require Qdrant, MinIO, PostgreSQL running.

    The bodies are skipped placeholders; the commented-out calls document
    the intended real-service interactions.
    """

    @pytest.mark.skip(reason="Requires running Qdrant")
    @pytest.mark.asyncio
    async def test_nibis_search(self):
        """Test NiBiS semantic search."""
        from admin_api import search_nibis
        from admin_api import NiBiSSearchRequest
        request = NiBiSSearchRequest(
            query="Gedichtanalyse Expressionismus",
            limit=5,
        )
        # This would require Qdrant running
        # results = await search_nibis(request)
        # assert len(results) <= 5

    @pytest.mark.skip(reason="Requires running MinIO")
    @pytest.mark.asyncio
    async def test_minio_upload(self):
        """Test MinIO document upload."""
        from minio_storage import upload_rag_document
        test_content = b"%PDF-1.4 test content"
        # This would require MinIO running
        # path = await upload_rag_document(
        #     file_data=test_content,
        #     filename="test.pdf",
        #     bundesland="ni",
        #     use_case="klausur",
        #     year=2024,
        # )
        # assert path is not None

    @pytest.mark.skip(reason="Requires running PostgreSQL")
    @pytest.mark.asyncio
    async def test_metrics_storage(self):
        """Test metrics storage in PostgreSQL."""
        from metrics_db import store_feedback, calculate_metrics
        # This would require PostgreSQL running
        # stored = await store_feedback(
        #     result_id="test-123",
        #     rating=4,
        #     query_text="test query",
        # )
        # assert stored is True
# =============================================================================
# ZIP Handling Tests
# =============================================================================
class TestZIPHandling:
    """Tests for ZIP file extraction."""

    def test_create_test_zip(self):
        """Test creating a ZIP file in memory."""
        # Declare the archive contents once, then write and verify them.
        entries = {
            "test1.pdf": b"%PDF-1.4 test content 1",
            "test2.pdf": b"%PDF-1.4 test content 2",
            "subfolder/test3.pdf": b"%PDF-1.4 test content 3",
        }
        buf = io.BytesIO()
        with zipfile.ZipFile(buf, 'w', zipfile.ZIP_DEFLATED) as archive:
            for name, payload in entries.items():
                archive.writestr(name, payload)
        buf.seek(0)
        # Re-open the archive and confirm every entry round-tripped.
        with zipfile.ZipFile(buf, 'r') as archive:
            listed = archive.namelist()
            for name in entries:
                assert name in listed

    def test_filter_macosx_files(self):
        """Test filtering out __MACOSX files from ZIP."""
        buf = io.BytesIO()
        with zipfile.ZipFile(buf, 'w', zipfile.ZIP_DEFLATED) as archive:
            archive.writestr("test.pdf", b"%PDF-1.4 test")
            archive.writestr("__MACOSX/._test.pdf", b"macosx metadata")
        buf.seek(0)
        with zipfile.ZipFile(buf, 'r') as archive:
            pdfs = []
            for name in archive.namelist():
                # Keep real PDFs; drop macOS resource-fork metadata entries.
                if name.lower().endswith(".pdf") and not name.startswith("__MACOSX"):
                    pdfs.append(name)
        assert len(pdfs) == 1
        assert pdfs[0] == "test.pdf"
# =============================================================================
# Embedding Tests
# =============================================================================
class TestEmbeddings:
    """Tests for embedding generation configuration in eh_pipeline."""

    def test_vector_dimensions(self):
        """Test that vector dimensions are configured correctly."""
        from eh_pipeline import get_vector_size, EMBEDDING_BACKEND
        size = get_vector_size()
        # NOTE(review): if EMBEDDING_BACKEND is neither "local" nor
        # "openai", this test asserts nothing and passes vacuously.
        if EMBEDDING_BACKEND == "local":
            assert size == 384  # all-MiniLM-L6-v2
        elif EMBEDDING_BACKEND == "openai":
            assert size == 1536  # text-embedding-3-small

    def test_chunking_config(self):
        """Test chunking configuration."""
        from eh_pipeline import CHUNK_SIZE, CHUNK_OVERLAP
        # Overlap must be strictly smaller than the chunk size, otherwise
        # chunking cannot make forward progress.
        assert CHUNK_SIZE > 0
        assert CHUNK_OVERLAP >= 0
        assert CHUNK_OVERLAP < CHUNK_SIZE
# =============================================================================
# Run Tests
# =============================================================================
# Allow running this test module directly (outside the pytest runner).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,623 @@
"""
Unit Tests for Vocab-Worksheet API
Tests cover:
- Session CRUD (create, read, list, delete)
- File upload (images and PDFs)
- PDF page handling (thumbnails, page selection)
- Vocabulary extraction (mocked Vision LLM)
- Vocabulary editing
- Worksheet generation
- PDF export
DSGVO Note: All tests run locally without external API calls.
BACKLOG: Feature not yet integrated into main.py
See: https://macmini:3002/infrastructure/tests
"""
import pytest
import sys
# Skip entire module if vocab_worksheet_api is not available
pytest.importorskip("vocab_worksheet_api", reason="vocab_worksheet_api not yet integrated - Backlog item")
# Mark all tests in this module as expected failures (backlog item)
pytestmark = pytest.mark.xfail(
reason="vocab_worksheet_api not yet integrated into main.py - Backlog item",
strict=False # Don't fail if test unexpectedly passes
)
import json
import uuid
import io
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient
# Import the main app and vocab-worksheet components
sys.path.insert(0, '..')
from main import app
from vocab_worksheet_api import (
_sessions,
_worksheets,
SessionStatus,
WorksheetType,
VocabularyEntry,
SessionCreate,
VocabularyUpdate,
WorksheetGenerateRequest,
parse_vocabulary_json,
)
# =============================================
# FIXTURES
# =============================================
@pytest.fixture
def client():
    """Test client for FastAPI app."""
    # Synchronous TestClient wrapping the full application under test.
    return TestClient(app)
@pytest.fixture(autouse=True)
def clear_storage():
    """Clear in-memory storage before each test.

    autouse=True keeps tests isolated: session/worksheet dicts are wiped
    both before (setup) and after (teardown) every test.
    """
    _sessions.clear()
    _worksheets.clear()
    yield
    _sessions.clear()
    _worksheets.clear()
@pytest.fixture
def sample_session_data():
    """Sample session creation data (English -> German vocabulary)."""
    return {
        "name": "Englisch Klasse 7 - Unit 3",
        "description": "Vokabeln aus Green Line 3",
        "source_language": "en",
        "target_language": "de"
    }
@pytest.fixture
def sample_vocabulary():
    """Sample vocabulary entries.

    Three entries shaped like the API's vocabulary dicts; fresh UUIDs are
    generated on every fixture invocation.
    """
    return [
        {
            "id": str(uuid.uuid4()),
            "english": "to achieve",
            "german": "erreichen, erzielen",
            "example_sentence": "She achieved her goals.",
            "word_type": "v",
            "source_page": 1
        },
        {
            "id": str(uuid.uuid4()),
            "english": "achievement",
            "german": "Leistung, Errungenschaft",
            "example_sentence": "That was a great achievement.",
            "word_type": "n",
            "source_page": 1
        },
        {
            "id": str(uuid.uuid4()),
            "english": "improve",
            "german": "verbessern",
            "example_sentence": "I want to improve my English.",
            "word_type": "v",
            "source_page": 1
        }
    ]
@pytest.fixture
def sample_image_bytes():
    """Create a minimal valid PNG image (1x1 pixel, white).

    Hand-assembled PNG byte stream: signature + IHDR + IDAT + IEND.
    Small enough to upload in tests without touching the filesystem.
    """
    # Minimal PNG: 1x1 white pixel
    png_data = bytes([
        0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,  # PNG signature
        0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,  # IHDR chunk
        0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01,  # 1x1
        0x08, 0x02, 0x00, 0x00, 0x00, 0x90, 0x77, 0x53,
        0xDE, 0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41,  # IDAT chunk
        0x54, 0x08, 0xD7, 0x63, 0xF8, 0xFF, 0xFF, 0xFF,
        0x00, 0x05, 0xFE, 0x02, 0xFE, 0xDC, 0xCC, 0x59,
        0xE7, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E,  # IEND chunk
        0x44, 0xAE, 0x42, 0x60, 0x82
    ])
    return png_data
# =============================================
# SESSION TESTS
# =============================================
class TestSessionCRUD:
    """Test session create, read, update, delete operations.

    Exercises /api/v1/vocab/sessions via the FastAPI TestClient; the
    autouse clear_storage fixture guarantees an empty store per test.
    """

    def test_create_session(self, client, sample_session_data):
        """Test creating a new vocabulary session."""
        response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        assert response.status_code == 200
        data = response.json()
        assert "id" in data
        assert data["name"] == sample_session_data["name"]
        assert data["description"] == sample_session_data["description"]
        assert data["source_language"] == "en"
        assert data["target_language"] == "de"
        # New sessions start in the "pending" state with no vocabulary.
        assert data["status"] == "pending"
        assert data["vocabulary_count"] == 0

    def test_create_session_minimal(self, client):
        """Test creating session with minimal data."""
        response = client.post("/api/v1/vocab/sessions", json={"name": "Test"})
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "Test"
        assert data["source_language"] == "en"  # Default
        assert data["target_language"] == "de"  # Default

    def test_list_sessions_empty(self, client):
        """Test listing sessions when none exist."""
        response = client.get("/api/v1/vocab/sessions")
        assert response.status_code == 200
        assert response.json() == []

    def test_list_sessions(self, client, sample_session_data):
        """Test listing sessions after creating some."""
        # Create 3 sessions
        for i in range(3):
            data = sample_session_data.copy()
            data["name"] = f"Session {i+1}"
            client.post("/api/v1/vocab/sessions", json=data)
        response = client.get("/api/v1/vocab/sessions")
        assert response.status_code == 200
        sessions = response.json()
        assert len(sessions) == 3

    def test_get_session(self, client, sample_session_data):
        """Test getting a specific session."""
        # Create session
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]
        # Get session
        response = client.get(f"/api/v1/vocab/sessions/{session_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == session_id
        assert data["name"] == sample_session_data["name"]

    def test_get_session_not_found(self, client):
        """Test getting non-existent session."""
        fake_id = str(uuid.uuid4())
        response = client.get(f"/api/v1/vocab/sessions/{fake_id}")
        assert response.status_code == 404
        assert "not found" in response.json()["detail"].lower()

    def test_delete_session(self, client, sample_session_data):
        """Test deleting a session."""
        # Create session
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]
        # Delete session
        response = client.delete(f"/api/v1/vocab/sessions/{session_id}")
        assert response.status_code == 200
        assert "deleted" in response.json()["message"].lower()
        # Verify it's gone
        get_response = client.get(f"/api/v1/vocab/sessions/{session_id}")
        assert get_response.status_code == 404

    def test_delete_session_not_found(self, client):
        """Test deleting non-existent session."""
        fake_id = str(uuid.uuid4())
        response = client.delete(f"/api/v1/vocab/sessions/{fake_id}")
        assert response.status_code == 404
# =============================================
# VOCABULARY TESTS
# =============================================
class TestVocabulary:
    """Test vocabulary operations (GET/PUT on a session's vocabulary)."""

    def test_get_vocabulary_empty(self, client, sample_session_data):
        """Test getting vocabulary from session with no vocabulary."""
        # Create session
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]
        # Get vocabulary
        response = client.get(f"/api/v1/vocab/sessions/{session_id}/vocabulary")
        assert response.status_code == 200
        data = response.json()
        assert data["session_id"] == session_id
        assert data["vocabulary"] == []

    def test_update_vocabulary(self, client, sample_session_data, sample_vocabulary):
        """Test updating vocabulary entries."""
        # Create session
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]
        # Update vocabulary
        response = client.put(
            f"/api/v1/vocab/sessions/{session_id}/vocabulary",
            json={"vocabulary": sample_vocabulary}
        )
        assert response.status_code == 200
        data = response.json()
        assert data["vocabulary_count"] == 3
        # Verify vocabulary was saved (round-trip through a GET)
        get_response = client.get(f"/api/v1/vocab/sessions/{session_id}/vocabulary")
        vocab_data = get_response.json()
        assert len(vocab_data["vocabulary"]) == 3

    def test_update_vocabulary_not_found(self, client, sample_vocabulary):
        """Test updating vocabulary for non-existent session."""
        fake_id = str(uuid.uuid4())
        response = client.put(
            f"/api/v1/vocab/sessions/{fake_id}/vocabulary",
            json={"vocabulary": sample_vocabulary}
        )
        assert response.status_code == 404
# =============================================
# WORKSHEET GENERATION TESTS
# =============================================
class TestWorksheetGeneration:
    """Test worksheet generation.

    PDF rendering is mocked (generate_worksheet_pdf) so tests cover the
    API workflow only, not the PDF output itself.
    """

    def test_generate_worksheet_no_vocabulary(self, client, sample_session_data):
        """Test generating worksheet without vocabulary fails."""
        # Create session (no vocabulary)
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]
        # Try to generate worksheet
        response = client.post(
            f"/api/v1/vocab/sessions/{session_id}/generate",
            json={
                "worksheet_types": ["en_to_de"],
                "include_solutions": True
            }
        )
        assert response.status_code == 400
        assert "no vocabulary" in response.json()["detail"].lower()

    @patch('vocab_worksheet_api.generate_worksheet_pdf')
    def test_generate_worksheet_success(
        self, mock_pdf, client, sample_session_data, sample_vocabulary
    ):
        """Test successful worksheet generation."""
        # Mock PDF generation to return fake bytes
        mock_pdf.return_value = b"%PDF-1.4 fake pdf content"
        # Create session with vocabulary
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]
        # Add vocabulary
        client.put(
            f"/api/v1/vocab/sessions/{session_id}/vocabulary",
            json={"vocabulary": sample_vocabulary}
        )
        # Generate worksheet
        response = client.post(
            f"/api/v1/vocab/sessions/{session_id}/generate",
            json={
                "worksheet_types": ["en_to_de", "de_to_en"],
                "include_solutions": True,
                "line_height": "large"
            }
        )
        assert response.status_code == 200
        data = response.json()
        assert "id" in data
        assert data["session_id"] == session_id
        assert "en_to_de" in data["worksheet_types"]
        assert "de_to_en" in data["worksheet_types"]

    def test_generate_worksheet_all_types(self, client, sample_session_data, sample_vocabulary):
        """Test that all worksheet types are accepted."""
        # Create session with vocabulary
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]
        # Add vocabulary
        client.put(
            f"/api/v1/vocab/sessions/{session_id}/vocabulary",
            json={"vocabulary": sample_vocabulary}
        )
        # Test each worksheet type with a fresh PDF mock per iteration.
        for wtype in ["en_to_de", "de_to_en", "copy", "gap_fill"]:
            with patch('vocab_worksheet_api.generate_worksheet_pdf') as mock_pdf:
                mock_pdf.return_value = b"%PDF-1.4 fake"
                response = client.post(
                    f"/api/v1/vocab/sessions/{session_id}/generate",
                    json={"worksheet_types": [wtype]}
                )
                assert response.status_code == 200, f"Failed for type: {wtype}"
# =============================================
# JSON PARSING TESTS
# =============================================
class TestJSONParsing:
    """Test vocabulary JSON parsing from LLM responses.

    parse_vocabulary_json() must tolerate real-world LLM output: valid
    JSON, JSON embedded in surrounding prose, extra fields, and garbage.
    """

    def test_parse_valid_json(self):
        """Test parsing valid JSON response."""
        response = '''
        {
            "vocabulary": [
                {"english": "achieve", "german": "erreichen"},
                {"english": "improve", "german": "verbessern"}
            ]
        }
        '''
        result = parse_vocabulary_json(response)
        assert len(result) == 2
        assert result[0].english == "achieve"
        assert result[0].german == "erreichen"

    def test_parse_json_with_extra_text(self):
        """Test parsing JSON with surrounding text."""
        response = '''
        Here is the extracted vocabulary:
        {
            "vocabulary": [
                {"english": "success", "german": "Erfolg"}
            ]
        }
        I found 1 vocabulary entry.
        '''
        result = parse_vocabulary_json(response)
        assert len(result) == 1
        assert result[0].english == "success"

    def test_parse_json_with_examples(self):
        """Test parsing JSON with example sentences."""
        response = '''
        {
            "vocabulary": [
                {
                    "english": "achieve",
                    "german": "erreichen",
                    "example": "She achieved her goals."
                }
            ]
        }
        '''
        result = parse_vocabulary_json(response)
        assert len(result) == 1
        # "example" in the JSON maps to the example_sentence attribute.
        assert result[0].example_sentence == "She achieved her goals."

    def test_parse_empty_response(self):
        """Test parsing empty/invalid response."""
        result = parse_vocabulary_json("")
        assert result == []
        result = parse_vocabulary_json("no json here")
        assert result == []

    def test_parse_json_missing_fields(self):
        """Test that entries without required fields are skipped."""
        response = '''
        {
            "vocabulary": [
                {"english": "valid", "german": "gueltig"},
                {"english": ""},
                {"german": "nur deutsch"},
                {"english": "also valid", "german": "auch gueltig"}
            ]
        }
        '''
        result = parse_vocabulary_json(response)
        # Only entries with both english and german should be included
        assert len(result) == 2
        assert result[0].english == "valid"
        assert result[1].english == "also valid"
# =============================================
# FILE UPLOAD TESTS
# =============================================
class TestFileUpload:
    """Test file upload functionality.

    Vision-LLM extraction is mocked so uploads are validated without a
    model backend (DSGVO: no external calls).
    """

    @patch('vocab_worksheet_api.extract_vocabulary_from_image')
    def test_upload_image(self, mock_extract, client, sample_session_data, sample_image_bytes):
        """Test uploading an image file."""
        # Mock extraction to return sample vocabulary
        # (entries, confidence, error-string) tuple shape.
        mock_extract.return_value = (
            [
                VocabularyEntry(
                    id=str(uuid.uuid4()),
                    english="test",
                    german="Test"
                )
            ],
            0.85,
            ""
        )
        # Create session
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]
        # Upload image
        files = {"file": ("test.png", io.BytesIO(sample_image_bytes), "image/png")}
        response = client.post(
            f"/api/v1/vocab/sessions/{session_id}/upload",
            files=files
        )
        assert response.status_code == 200
        data = response.json()
        assert data["session_id"] == session_id
        assert data["vocabulary_count"] == 1

    def test_upload_invalid_file_type(self, client, sample_session_data):
        """Test uploading invalid file type."""
        # Create session
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]
        # Try to upload a text file
        files = {"file": ("test.txt", io.BytesIO(b"hello"), "text/plain")}
        response = client.post(
            f"/api/v1/vocab/sessions/{session_id}/upload",
            files=files
        )
        assert response.status_code == 400
        assert "supported" in response.json()["detail"].lower()
# =============================================
# STATUS WORKFLOW TESTS
# =============================================
class TestSessionStatus:
    """Test session status transitions.

    Expected lifecycle: pending -> extracted (after upload/OCR) ->
    completed (after worksheet generation).
    """

    def test_initial_status_pending(self, client, sample_session_data):
        """Test that new session has PENDING status."""
        response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        assert response.json()["status"] == "pending"

    @patch('vocab_worksheet_api.extract_vocabulary_from_image')
    def test_status_after_extraction(self, mock_extract, client, sample_session_data, sample_image_bytes):
        """Test that status becomes EXTRACTED after processing."""
        # Even an empty extraction result should flip the status.
        mock_extract.return_value = ([], 0.0, "")
        # Create and upload
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]
        files = {"file": ("test.png", io.BytesIO(sample_image_bytes), "image/png")}
        client.post(f"/api/v1/vocab/sessions/{session_id}/upload", files=files)
        # Check status
        get_response = client.get(f"/api/v1/vocab/sessions/{session_id}")
        assert get_response.json()["status"] == "extracted"

    @patch('vocab_worksheet_api.generate_worksheet_pdf')
    def test_status_after_generation(self, mock_pdf, client, sample_session_data, sample_vocabulary):
        """Test that status becomes COMPLETED after worksheet generation."""
        mock_pdf.return_value = b"%PDF"
        # Create session with vocabulary
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]
        # Add vocabulary
        client.put(
            f"/api/v1/vocab/sessions/{session_id}/vocabulary",
            json={"vocabulary": sample_vocabulary}
        )
        # Generate worksheet
        client.post(
            f"/api/v1/vocab/sessions/{session_id}/generate",
            json={"worksheet_types": ["en_to_de"]}
        )
        # Check status
        get_response = client.get(f"/api/v1/vocab/sessions/{session_id}")
        assert get_response.json()["status"] == "completed"
# =============================================
# EDGE CASES
# =============================================
class TestEdgeCases:
    """Test edge cases and error handling."""

    def test_session_with_special_characters(self, client):
        """Test session with special characters in name."""
        # Umlauts, sharp-s, euro sign and @ must survive the round-trip.
        response = client.post(
            "/api/v1/vocab/sessions",
            json={"name": "Englisch Klasse 7 - äöü ß € @"}
        )
        assert response.status_code == 200
        assert "äöü" in response.json()["name"]

    def test_vocabulary_with_long_entries(self, client, sample_session_data):
        """Test vocabulary with very long entries."""
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]
        # Create vocabulary with long entries
        long_vocab = [{
            "id": str(uuid.uuid4()),
            "english": "a" * 100,
            "german": "b" * 200,
            "example_sentence": "c" * 500
        }]
        response = client.put(
            f"/api/v1/vocab/sessions/{session_id}/vocabulary",
            json={"vocabulary": long_vocab}
        )
        assert response.status_code == 200

    def test_sessions_limit(self, client, sample_session_data):
        """Test session listing with limit parameter."""
        # Create 10 sessions
        for i in range(10):
            data = sample_session_data.copy()
            data["name"] = f"Session {i+1}"
            client.post("/api/v1/vocab/sessions", json=data)
        # Get with limit
        response = client.get("/api/v1/vocab/sessions?limit=5")
        assert response.status_code == 200
        assert len(response.json()) == 5
# =============================================
# RUN TESTS
# =============================================
# Allow running this test module directly (python <file>) instead of via pytest.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,539 @@
"""
Unit Tests for Worksheet Editor API
Tests cover:
- Worksheet CRUD (create, read, list, delete)
- AI Image generation (mocked)
- PDF Export
- Health check
DSGVO Note: All tests run locally without external API calls.
BACKLOG: Feature not yet integrated into main.py
See: https://macmini:3002/infrastructure/tests
"""
import json
import io
import os
import sys
import uuid
from datetime import datetime
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient

# Make the backend directory importable regardless of pytest's working
# directory. The previous `sys.path.insert(0, '..')` was relative to the
# CWD (broken when tests run from the repo root) and, worse, it ran AFTER
# pytest.importorskip, so the skip decision was made before the module
# could even be found on sys.path.
_BACKEND_DIR = str(Path(__file__).resolve().parent.parent)
if _BACKEND_DIR not in sys.path:
    sys.path.insert(0, _BACKEND_DIR)

# Skip entire module if worksheet_editor_api is not available
pytest.importorskip("worksheet_editor_api", reason="worksheet_editor_api not yet integrated - Backlog item")

# Mark all tests in this module as expected failures (backlog item)
pytestmark = pytest.mark.xfail(
    reason="worksheet_editor_api not yet integrated into main.py - Backlog item",
    strict=False,  # Don't fail if test unexpectedly passes
)

# Import the main app and worksheet-editor components (only reachable once
# the backend directory is on sys.path and importorskip has passed).
from main import app
from worksheet_editor_api import (
    worksheets_db,
    AIImageStyle,
    WorksheetStatus,
    WORKSHEET_STORAGE_DIR,
)
# =============================================
# FIXTURES
# =============================================
@pytest.fixture
def client():
    """Provide a FastAPI test client bound to the main application."""
    test_client = TestClient(app)
    return test_client
@pytest.fixture(autouse=True)
def clear_storage():
    """Reset the in-memory worksheet store around every test."""
    worksheets_db.clear()
    # Remove on-disk artifacts left behind by previous test runs.
    if os.path.exists(WORKSHEET_STORAGE_DIR):
        for name in os.listdir(WORKSHEET_STORAGE_DIR):
            if name.startswith(("test_", "ws_test_")):
                os.remove(os.path.join(WORKSHEET_STORAGE_DIR, name))
    yield
    worksheets_db.clear()
@pytest.fixture
def sample_worksheet_data():
    """Sample single-page worksheet payload (A4 portrait)."""
    # Canvas with one heading text object and one rectangle.
    canvas = json.dumps({
        "objects": [
            {
                "type": "text",
                "text": "Überschrift",
                "left": 100,
                "top": 50,
                "fontSize": 24,
            },
            {
                "type": "rect",
                "left": 100,
                "top": 100,
                "width": 200,
                "height": 100,
            },
        ]
    })
    return {
        "title": "Test Arbeitsblatt",
        "description": "Testbeschreibung",
        "pages": [{"id": "page_1", "index": 0, "canvasJSON": canvas}],
        "pageFormat": {
            "width": 210,
            "height": 297,
            "orientation": "portrait",
            "margins": {"top": 15, "right": 15, "bottom": 15, "left": 15},
        },
    }
@pytest.fixture
def sample_multipage_worksheet():
    """Sample worksheet payload with three empty pages."""
    pages = [
        {"id": f"page_{i + 1}", "index": i, "canvasJSON": "{}"}
        for i in range(3)
    ]
    return {"title": "Mehrseitiges Arbeitsblatt", "pages": pages}
# =============================================
# WORKSHEET CRUD TESTS
# =============================================
class TestWorksheetCRUD:
    """Tests for Worksheet CRUD operations."""

    @staticmethod
    def _save(client, payload):
        """POST *payload* to the save endpoint and return the response."""
        return client.post("/api/v1/worksheet/save", json=payload)

    def test_create_worksheet(self, client, sample_worksheet_data):
        """Test creating a new worksheet."""
        response = self._save(client, sample_worksheet_data)
        assert response.status_code == 200
        body = response.json()
        assert "id" in body
        assert body["id"].startswith("ws_")
        assert body["title"] == sample_worksheet_data["title"]
        assert body["description"] == sample_worksheet_data["description"]
        assert len(body["pages"]) == 1
        assert "createdAt" in body
        assert "updatedAt" in body

    def test_update_worksheet(self, client, sample_worksheet_data):
        """Test updating an existing worksheet."""
        worksheet_id = self._save(client, sample_worksheet_data).json()["id"]
        # Saving again with the same id should update, not create.
        sample_worksheet_data["id"] = worksheet_id
        sample_worksheet_data["title"] = "Aktualisiertes Arbeitsblatt"
        response = self._save(client, sample_worksheet_data)
        assert response.status_code == 200
        body = response.json()
        assert body["id"] == worksheet_id
        assert body["title"] == "Aktualisiertes Arbeitsblatt"

    def test_get_worksheet(self, client, sample_worksheet_data):
        """Test retrieving a worksheet by ID."""
        worksheet_id = self._save(client, sample_worksheet_data).json()["id"]
        response = client.get(f"/api/v1/worksheet/{worksheet_id}")
        assert response.status_code == 200
        body = response.json()
        assert body["id"] == worksheet_id
        assert body["title"] == sample_worksheet_data["title"]

    def test_get_nonexistent_worksheet(self, client):
        """Test retrieving a non-existent worksheet."""
        response = client.get("/api/v1/worksheet/ws_nonexistent123")
        assert response.status_code == 404
        assert "not found" in response.json()["detail"].lower()

    def test_list_worksheets(self, client, sample_worksheet_data):
        """Test listing all worksheets."""
        # Create three distinctly-titled worksheets.
        for i in range(1, 4):
            self._save(client, dict(sample_worksheet_data, title=f"Arbeitsblatt {i}"))
        response = client.get("/api/v1/worksheet/list/all")
        assert response.status_code == 200
        body = response.json()
        assert "worksheets" in body
        assert "total" in body
        assert body["total"] >= 3

    def test_delete_worksheet(self, client, sample_worksheet_data):
        """Test deleting a worksheet."""
        worksheet_id = self._save(client, sample_worksheet_data).json()["id"]
        delete_response = client.delete(f"/api/v1/worksheet/{worksheet_id}")
        assert delete_response.status_code == 200
        assert delete_response.json()["status"] == "deleted"
        # A subsequent GET must report the worksheet as gone.
        assert client.get(f"/api/v1/worksheet/{worksheet_id}").status_code == 404

    def test_delete_nonexistent_worksheet(self, client):
        """Test deleting a non-existent worksheet."""
        response = client.delete("/api/v1/worksheet/ws_nonexistent123")
        assert response.status_code == 404
# =============================================
# MULTIPAGE TESTS
# =============================================
class TestMultipageWorksheet:
    """Tests for multi-page worksheet functionality."""

    def test_create_multipage_worksheet(self, client, sample_multipage_worksheet):
        """Test creating a worksheet with multiple pages."""
        response = client.post(
            "/api/v1/worksheet/save", json=sample_multipage_worksheet
        )
        assert response.status_code == 200
        assert len(response.json()["pages"]) == 3

    def test_page_indices(self, client, sample_multipage_worksheet):
        """Test that page indices are preserved."""
        saved = client.post(
            "/api/v1/worksheet/save", json=sample_multipage_worksheet
        ).json()
        # Each page's stored index must match its position in the list.
        indices = [page["index"] for page in saved["pages"]]
        assert indices == list(range(len(saved["pages"])))
# =============================================
# AI IMAGE TESTS
# =============================================
class TestAIImageGeneration:
    """Tests for AI image generation."""

    @staticmethod
    def _request_image(client, prompt, style="educational", width=256, height=256):
        """POST an image-generation request and return the response."""
        return client.post(
            "/api/v1/worksheet/ai-image",
            json={"prompt": prompt, "style": style, "width": width, "height": height},
        )

    def test_ai_image_with_valid_prompt(self, client):
        """Test AI image generation with valid prompt."""
        response = self._request_image(
            client, "A friendly cartoon dog reading a book", style="cartoon"
        )
        # Should return 200 (with placeholder) since Ollama may not be running
        assert response.status_code == 200
        body = response.json()
        assert "image_base64" in body
        assert body["image_base64"].startswith("data:image/png;base64,")
        assert "prompt_used" in body

    def test_ai_image_styles(self, client):
        """Test different AI image styles."""
        for style in ("realistic", "cartoon", "sketch", "clipart", "educational"):
            response = self._request_image(client, "Test image", style=style)
            assert response.status_code == 200
            prompt_used = response.json()["prompt_used"].lower()
            assert style in prompt_used or "test image" in prompt_used

    def test_ai_image_empty_prompt(self, client):
        """Test AI image generation with empty prompt."""
        response = self._request_image(client, "")
        assert response.status_code == 422  # Validation error

    def test_ai_image_invalid_dimensions(self, client):
        """Test AI image generation with invalid dimensions."""
        # Width of 50 is below the allowed minimum.
        response = self._request_image(client, "Test", width=50)
        assert response.status_code == 422  # Validation error
# =============================================
# CANVAS JSON TESTS
# =============================================
class TestCanvasJSON:
    """Tests for canvas JSON handling."""

    def test_save_and_load_canvas_json(self, client):
        """Test that canvas JSON is preserved correctly."""
        canvas_data = {
            "objects": [
                {"type": "text", "text": "Hello", "left": 10, "top": 10},
                {"type": "rect", "left": 50, "top": 50, "width": 100, "height": 50},
            ],
            "background": "#ffffff",
        }
        payload = {
            "title": "Canvas Test",
            "pages": [
                {"id": "page_1", "index": 0, "canvasJSON": json.dumps(canvas_data)}
            ],
        }
        worksheet_id = client.post(
            "/api/v1/worksheet/save", json=payload
        ).json()["id"]
        # Round-trip: load the worksheet back and compare the canvas objects.
        loaded = client.get(f"/api/v1/worksheet/{worksheet_id}").json()
        restored = json.loads(loaded["pages"][0]["canvasJSON"])
        assert restored["objects"] == canvas_data["objects"]

    def test_empty_canvas_json(self, client):
        """Test worksheet with empty canvas JSON."""
        payload = {
            "title": "Empty Canvas",
            "pages": [{"id": "page_1", "index": 0, "canvasJSON": "{}"}],
        }
        response = client.post("/api/v1/worksheet/save", json=payload)
        assert response.status_code == 200
# =============================================
# PAGE FORMAT TESTS
# =============================================
class TestPageFormat:
    """Tests for page format handling."""

    def test_default_page_format(self, client):
        """Test that default page format is applied."""
        payload = {
            "title": "Default Format",
            "pages": [{"id": "p1", "index": 0, "canvasJSON": "{}"}],
        }
        fmt = client.post("/api/v1/worksheet/save", json=payload).json()["pageFormat"]
        # A4 portrait (210 x 297 mm) is the expected default.
        assert fmt["width"] == 210
        assert fmt["height"] == 297
        assert fmt["orientation"] == "portrait"

    def test_custom_page_format(self, client):
        """Test custom page format."""
        payload = {
            "title": "Custom Format",
            "pages": [{"id": "p1", "index": 0, "canvasJSON": "{}"}],
            "pageFormat": {
                "width": 297,
                "height": 210,
                "orientation": "landscape",
                "margins": {"top": 20, "right": 20, "bottom": 20, "left": 20},
            },
        }
        fmt = client.post("/api/v1/worksheet/save", json=payload).json()["pageFormat"]
        assert fmt["width"] == 297
        assert fmt["height"] == 210
        assert fmt["orientation"] == "landscape"
# =============================================
# HEALTH CHECK TESTS
# =============================================
class TestHealthCheck:
    """Tests for health check endpoint."""

    def test_health_check(self, client):
        """Test health check endpoint."""
        response = client.get("/api/v1/worksheet/health/check")
        assert response.status_code == 200
        body = response.json()
        assert "status" in body
        assert body["status"] == "healthy"
        # The report must include storage and PDF-backend fields.
        for key in ("storage", "reportlab"):
            assert key in body
# =============================================
# PDF EXPORT TESTS
# =============================================
class TestPDFExport:
    """Tests for PDF export functionality."""

    def test_export_pdf(self, client, sample_worksheet_data):
        """Test PDF export of a worksheet."""
        worksheet_id = client.post(
            "/api/v1/worksheet/save", json=sample_worksheet_data
        ).json()["id"]
        export_response = client.post(f"/api/v1/worksheet/{worksheet_id}/export-pdf")
        # 501 is acceptable when reportlab is not installed on this host.
        assert export_response.status_code in [200, 501]
        if export_response.status_code == 200:
            assert export_response.headers["content-type"] == "application/pdf"
            assert "attachment" in export_response.headers.get("content-disposition", "")

    def test_export_nonexistent_worksheet_pdf(self, client):
        """Test PDF export of non-existent worksheet."""
        response = client.post("/api/v1/worksheet/ws_nonexistent/export-pdf")
        assert response.status_code in [404, 501]
# =============================================
# ERROR HANDLING TESTS
# =============================================
class TestErrorHandling:
    """Tests for error handling."""

    def test_invalid_json(self, client):
        """Test handling of invalid JSON."""
        # A raw non-JSON body sent with a JSON content type must be rejected
        # by request parsing (422) before reaching the endpoint handler.
        response = client.post(
            "/api/v1/worksheet/save",
            content="not valid json",
            headers={"Content-Type": "application/json"}
        )
        assert response.status_code == 422

    def test_missing_required_fields(self, client):
        """Test handling of missing required fields."""
        response = client.post(
            "/api/v1/worksheet/save",
            json={"pages": []}  # Missing title
        )
        assert response.status_code == 422

    def test_empty_pages_array(self, client):
        """Test worksheet with empty pages array."""
        response = client.post(
            "/api/v1/worksheet/save",
            json={
                "title": "No Pages",
                "pages": []
            }
        )
        # Should still work - empty worksheets are allowed
        assert response.status_code == 200