fix: Restore all files lost during destructive rebase
A previous `git pull --rebase origin main` dropped 177 local commits,
losing 3400+ files across admin-v2, backend, studio-v2, website,
klausur-service, and many other services. The partial restore attempt
(660295e2) only recovered some files.
This commit restores all missing files from pre-rebase ref 98933f5e
while preserving post-rebase additions (night-scheduler, night-mode UI,
NightModeWidget dashboard integration).
Restored features include:
- AI Module Sidebar (FAB), OCR Labeling, OCR Compare
- GPU Dashboard, RAG Pipeline, Magic Help
- Klausur-Korrektur (8 files), Abitur-Archiv (5+ files)
- Companion, Zeugnisse-Crawler, Screen Flow
- Full backend, studio-v2, website, klausur-service
- All compliance SDKs, agent-core, voice-service
- CI/CD configs, documentation, scripts
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
1
klausur-service/backend/tests/__init__.py
Normal file
1
klausur-service/backend/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# BYOEH Test Suite
|
||||
14
klausur-service/backend/tests/conftest.py
Normal file
14
klausur-service/backend/tests/conftest.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Pytest configuration for klausur-service tests.
|
||||
|
||||
Ensures local modules (hyde, hybrid_search, rag_evaluation, etc.)
|
||||
can be imported by adding the backend directory to sys.path.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the backend directory to sys.path so local modules can be imported
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
if str(backend_dir) not in sys.path:
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
769
klausur-service/backend/tests/test_advanced_rag.py
Normal file
769
klausur-service/backend/tests/test_advanced_rag.py
Normal file
@@ -0,0 +1,769 @@
|
||||
"""
|
||||
Tests for Advanced RAG Features
|
||||
|
||||
Tests for the newly implemented RAG quality improvements:
|
||||
- HyDE (Hypothetical Document Embeddings)
|
||||
- Hybrid Search (Dense + Sparse/BM25)
|
||||
- RAG Evaluation (RAGAS-inspired metrics)
|
||||
- PDF Extraction (Unstructured.io, PyMuPDF, PyPDF2)
|
||||
- Self-RAG / Corrective RAG
|
||||
|
||||
Run with: pytest tests/test_advanced_rag.py -v
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import asyncio
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# HyDE Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestHyDE:
    """Unit tests for the HyDE (Hypothetical Document Embeddings) module."""

    def test_hyde_config(self):
        """The HyDE info dict exposes the expected configuration keys."""
        from hyde import HYDE_ENABLED, HYDE_LLM_BACKEND, HYDE_MODEL, get_hyde_info

        cfg = get_hyde_info()
        for key in ("enabled", "llm_backend", "model"):
            assert key in cfg

    def test_hyde_prompt_template(self):
        """The prompt template carries the query placeholder and domain terms."""
        from hyde import HYDE_PROMPT_TEMPLATE

        template = HYDE_PROMPT_TEMPLATE
        assert "{query}" in template
        assert "Erwartungshorizont" in template
        assert "Bildungsstandards" in template

    @pytest.mark.asyncio
    async def test_generate_hypothetical_document_disabled(self):
        """With HyDE disabled, the original query is returned unchanged."""
        from hyde import generate_hypothetical_document

        with patch("hyde.HYDE_ENABLED", False):
            answer = await generate_hypothetical_document("Test query")
            assert answer == "Test query"

    @pytest.mark.asyncio
    async def test_hyde_search_fallback(self):
        """hyde_search degrades to a plain search when HyDE is disabled."""
        from hyde import hyde_search

        async def fake_search(query, **kwargs):
            return [{"id": "1", "text": f"Result for: {query}"}]

        with patch("hyde.HYDE_ENABLED", False):
            outcome = await hyde_search(
                query="Test query",
                search_func=fake_search,
            )

        assert outcome["hyde_used"] is False
        # With HyDE off, the original query passes through untouched.
        assert outcome["original_query"] == "Test query"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Hybrid Search Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestHybridSearch:
    """Unit tests for the hybrid (dense + BM25 sparse) search module."""

    def test_hybrid_search_config(self):
        """Dense and sparse weights are exposed and sum to 1.0."""
        from hybrid_search import HYBRID_ENABLED, DENSE_WEIGHT, SPARSE_WEIGHT, get_hybrid_search_info

        cfg = get_hybrid_search_info()
        for key in ("enabled", "dense_weight", "sparse_weight"):
            assert key in cfg
        assert cfg["dense_weight"] + cfg["sparse_weight"] == pytest.approx(1.0)

    def test_bm25_tokenization(self):
        """German stopwords are dropped while lowercased content words survive."""
        from hybrid_search import BM25

        tokens = BM25()._tokenize("Der Erwartungshorizont für die Abiturprüfung")

        for stopword in ("der", "für", "die"):
            assert stopword not in tokens
        assert "erwartungshorizont" in tokens
        assert "abiturprüfung" in tokens

    def test_bm25_stopwords(self):
        """The German stopword list is populated with common function words."""
        from hybrid_search import GERMAN_STOPWORDS

        for word in ("der", "die", "und", "ist"):
            assert word in GERMAN_STOPWORDS
        assert len(GERMAN_STOPWORDS) > 50

    def test_bm25_fit_and_search(self):
        """Fitting indexes the corpus; search ranks the matching doc first."""
        from hybrid_search import BM25

        corpus = [
            "Der Erwartungshorizont für Mathematik enthält Bewertungskriterien.",
            "Die Gedichtanalyse erfordert formale und inhaltliche Aspekte.",
            "Biologie Klausur Bewertung nach Anforderungsbereichen.",
        ]

        index = BM25()
        index.fit(corpus)

        assert index.N == 3
        assert len(index.corpus) == 3

        hits = index.search("Mathematik Bewertungskriterien", top_k=2)
        assert len(hits) == 2
        assert hits[0][0] == 0  # the Mathematik document must rank highest

    def test_normalize_scores(self):
        """Normalization maps scores onto [0, 1] and preserves length."""
        from hybrid_search import normalize_scores

        raw = [0.5, 1.0, 0.0, 0.75]
        scaled = normalize_scores(raw)

        assert min(scaled) == 0.0
        assert max(scaled) == 1.0
        assert len(scaled) == len(raw)

    def test_normalize_scores_empty(self):
        """An empty input yields an empty output."""
        from hybrid_search import normalize_scores

        assert normalize_scores([]) == []

    def test_normalize_scores_same_value(self):
        """Constant score lists collapse to all-ones instead of dividing by zero."""
        from hybrid_search import normalize_scores

        scaled = normalize_scores([0.5, 0.5, 0.5])
        assert all(value == 1.0 for value in scaled)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# RAG Evaluation Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestRAGEvaluation:
    """Unit tests for the RAGAS-inspired RAG evaluation module."""

    def test_evaluation_config(self):
        """The evaluation info dict reports its status and supported metrics."""
        from rag_evaluation import EVALUATION_ENABLED, EVAL_MODEL, get_evaluation_info

        cfg = get_evaluation_info()
        assert "enabled" in cfg
        assert "metrics" in cfg
        assert "context_precision" in cfg["metrics"]
        assert "faithfulness" in cfg["metrics"]

    def test_text_similarity(self):
        """Jaccard similarity: identical -> 1.0, disjoint -> 0.0, else between."""
        from rag_evaluation import _text_similarity

        assert _text_similarity("Hello world", "Hello world") == 1.0
        assert _text_similarity("Hello world", "Goodbye universe") == 0.0
        partial = _text_similarity("Hello beautiful world", "Hello cruel world")
        assert 0 < partial < 1

    def test_context_precision(self):
        """Precision = relevant retrieved / total retrieved."""
        from rag_evaluation import calculate_context_precision

        retrieved_docs = ["Doc A about topic X", "Doc B about topic Y"]
        relevant_docs = ["Doc A about topic X"]

        score = calculate_context_precision("query", retrieved_docs, relevant_docs)
        assert score == 0.5  # one relevant hit out of two retrieved

    def test_context_precision_empty_retrieved(self):
        """No retrieved documents means zero precision."""
        from rag_evaluation import calculate_context_precision

        assert calculate_context_precision("query", [], ["relevant"]) == 0.0

    def test_context_recall(self):
        """Recall = relevant found / total relevant."""
        from rag_evaluation import calculate_context_recall

        retrieved_docs = ["Doc A about topic X"]
        relevant_docs = ["Doc A about topic X", "Doc B about topic Y"]

        score = calculate_context_recall("query", retrieved_docs, relevant_docs)
        assert score == 0.5  # found one of two relevant docs

    def test_context_recall_no_relevant(self):
        """With no relevant documents there is nothing to miss: recall is 1.0."""
        from rag_evaluation import calculate_context_recall

        assert calculate_context_recall("query", ["something"], []) == 1.0

    def test_load_save_eval_results(self):
        """Results round-trip through the evaluation results file."""
        from rag_evaluation import _load_eval_results, _save_eval_results
        from pathlib import Path
        import tempfile
        import os

        # Redirect the module's results file to a throwaway temp path.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as handle:
            temp_path = Path(handle.name)  # module expects a Path object

        try:
            with patch("rag_evaluation.EVAL_RESULTS_FILE", temp_path):
                payload = [{"test": "data", "score": 0.8}]
                _save_eval_results(payload)
                assert _load_eval_results() == payload
        finally:
            os.unlink(temp_path)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# PDF Extraction Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestPDFExtraction:
    """Unit tests for the PDF extraction module."""

    def test_pdf_extraction_config(self):
        """The extraction info dict names the backend and alternatives."""
        from pdf_extraction import PDF_BACKEND, get_pdf_extraction_info

        cfg = get_pdf_extraction_info()
        for key in ("configured_backend", "available_backends", "recommended"):
            assert key in cfg

    def test_detect_available_backends(self):
        """Backend detection returns a list of license-compliant names."""
        from pdf_extraction import _detect_available_backends

        detected = _detect_available_backends()
        assert isinstance(detected, list)
        # In the Docker container at least pypdf (BSD) or unstructured
        # (Apache 2.0) should be available; a bare local environment may
        # legitimately report none.
        # NOTE: PyMuPDF (AGPL) is NOT installed by default for license compliance.
        for name in detected:
            assert name in ["pypdf", "unstructured", "pymupdf"]

    def test_pdf_extraction_result_class(self):
        """PDFExtractionResult stores fields and derives counts in to_dict()."""
        from pdf_extraction import PDFExtractionResult

        extraction = PDFExtractionResult(
            text="Extracted text",
            backend_used="pypdf",
            pages=5,
            elements=[{"type": "paragraph"}],
            tables=[{"text": "table data"}],
            metadata={"key": "value"},
        )

        assert extraction.text == "Extracted text"
        assert extraction.backend_used == "pypdf"
        assert extraction.pages == 5
        assert len(extraction.elements) == 1
        assert len(extraction.tables) == 1

        # The dict view adds derived count fields.
        as_dict = extraction.to_dict()
        assert as_dict["text"] == "Extracted text"
        assert as_dict["element_count"] == 1
        assert as_dict["table_count"] == 1

    def test_pdf_extraction_error(self):
        """PDFExtractionError is raisable and caught by pytest.raises."""
        from pdf_extraction import PDFExtractionError

        with pytest.raises(PDFExtractionError):
            raise PDFExtractionError("Test error")

    @pytest.mark.xfail(reason="_extract_with_pypdf is internal function not exposed in API")
    def test_pypdf_extraction(self):
        """Drive _extract_with_pypdf with a minimal hand-written PDF."""
        from pdf_extraction import _extract_with_pypdf, PDFExtractionError

        # A minimal single-page PDF: just enough structure to exercise parsing.
        simple_pdf = b"""%PDF-1.4
1 0 obj
<< /Type /Catalog /Pages 2 0 R >>
endobj
2 0 obj
<< /Type /Pages /Kids [3 0 R] /Count 1 >>
endobj
3 0 obj
<< /Type /Page /Parent 2 0 R /MediaBox [0 0 612 792] /Contents 4 0 R >>
endobj
4 0 obj
<< /Length 44 >>
stream
BT
/F1 12 Tf
100 700 Td
(Hello World) Tj
ET
endstream
endobj
xref
0 5
0000000000 65535 f
0000000009 00000 n
0000000058 00000 n
0000000115 00000 n
0000000206 00000 n
trailer
<< /Size 5 /Root 1 0 R >>
startxref
300
%%EOF"""

        # The PDF is deliberately tiny, so a parse failure is acceptable;
        # this exercises the code path either way.
        try:
            extraction = _extract_with_pypdf(simple_pdf)
            assert extraction.backend_used == "pypdf"
        except PDFExtractionError:
            # Expected for minimal PDF
            pass
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Self-RAG Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestSelfRAG:
    """Unit tests for the Self-RAG / Corrective RAG module."""

    def test_self_rag_config(self):
        """The Self-RAG info dict exposes thresholds and feature flags."""
        from self_rag import (
            SELF_RAG_ENABLED, RELEVANCE_THRESHOLD, GROUNDING_THRESHOLD,
            MAX_RETRIEVAL_ATTEMPTS, get_self_rag_info
        )

        cfg = get_self_rag_info()
        for key in (
            "enabled",
            "relevance_threshold",
            "grounding_threshold",
            "max_retrieval_attempts",
            "features",
        ):
            assert key in cfg

    def test_retrieval_decision_enum(self):
        """RetrievalDecision members map to their expected string values."""
        from self_rag import RetrievalDecision

        expected = {
            RetrievalDecision.SUFFICIENT: "sufficient",
            RetrievalDecision.NEEDS_MORE: "needs_more",
            RetrievalDecision.REFORMULATE: "reformulate",
            RetrievalDecision.FALLBACK: "fallback",
        }
        for member, value in expected.items():
            assert member.value == value

    @pytest.mark.asyncio
    async def test_grade_document_relevance_no_llm(self):
        """Without an API key, grading falls back to keyword overlap."""
        from self_rag import grade_document_relevance

        with patch("self_rag.OPENAI_API_KEY", ''):
            score, reason = await grade_document_relevance(
                "Mathematik Bewertungskriterien",
                "Der Erwartungshorizont für Mathematik enthält klare Bewertungskriterien."
            )

        assert 0 <= score <= 1
        assert "Keyword" in reason

    @pytest.mark.asyncio
    async def test_decide_retrieval_strategy_empty_docs(self):
        """No documents on the first attempt triggers query reformulation."""
        from self_rag import decide_retrieval_strategy, RetrievalDecision

        decision, meta = await decide_retrieval_strategy("query", [], attempt=1)
        assert decision == RetrievalDecision.REFORMULATE
        assert "No documents" in meta.get("reason", "")

    @pytest.mark.asyncio
    async def test_decide_retrieval_strategy_max_attempts(self):
        """Hitting the attempt ceiling forces the fallback path."""
        from self_rag import decide_retrieval_strategy, RetrievalDecision, MAX_RETRIEVAL_ATTEMPTS

        decision, _meta = await decide_retrieval_strategy(
            "query", [], attempt=MAX_RETRIEVAL_ATTEMPTS
        )
        assert decision == RetrievalDecision.FALLBACK

    @pytest.mark.asyncio
    async def test_reformulate_query_no_llm(self):
        """Without an LLM, reformulation expands known abbreviations (or no-ops)."""
        from self_rag import reformulate_query

        with patch("self_rag.OPENAI_API_KEY", ''):
            rewritten = await reformulate_query("EA Mathematik Anforderungen")

        # Either the abbreviation was expanded or the query passed through.
        assert "erhöhtes Anforderungsniveau" in rewritten or rewritten == "EA Mathematik Anforderungen"

    @pytest.mark.asyncio
    async def test_filter_relevant_documents(self):
        """Filtering annotates every document with a relevance score."""
        from self_rag import filter_relevant_documents

        candidates = [
            {"text": "Mathematik Bewertungskriterien für Abitur"},
            {"text": "Rezept für Schokoladenkuchen"},
            {"text": "Erwartungshorizont Mathematik eA"},
        ]

        with patch("self_rag.OPENAI_API_KEY", ''):
            relevant, filtered = await filter_relevant_documents(
                "Mathematik Abitur Bewertung",
                candidates,
                threshold=0.1  # Low threshold for keyword matching
            )

        # Every document — kept or dropped — carries a score.
        for doc in relevant + filtered:
            assert "relevance_score" in doc

    @pytest.mark.asyncio
    async def test_self_rag_retrieve_disabled(self):
        """When disabled, self_rag_retrieve passes the raw search results through."""
        from self_rag import self_rag_retrieve

        async def fake_search(query, **kwargs):
            return [{"id": "1", "text": "Result"}]

        with patch("self_rag.SELF_RAG_ENABLED", False):
            outcome = await self_rag_retrieve(
                query="Test query",
                search_func=fake_search,
            )

        assert outcome["self_rag_enabled"] is False
        assert len(outcome["results"]) == 1
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration Tests - Module Availability
|
||||
# =============================================================================
|
||||
|
||||
class TestModuleAvailability:
    """Smoke tests: every advanced RAG module imports and exposes its API."""

    def test_hyde_import(self):
        """HyDE module exports are importable and callable."""
        from hyde import generate_hypothetical_document, hyde_search, get_hyde_info, HYDE_ENABLED

        assert callable(generate_hypothetical_document)
        assert callable(hyde_search)

    def test_hybrid_search_import(self):
        """Hybrid search module exports are importable and callable."""
        from hybrid_search import BM25, hybrid_search, get_hybrid_search_info, HYBRID_ENABLED

        assert callable(hybrid_search)
        assert BM25 is not None

    def test_rag_evaluation_import(self):
        """RAG evaluation module exports are importable and callable."""
        from rag_evaluation import (
            calculate_context_precision,
            calculate_context_recall,
            evaluate_faithfulness,
            evaluate_answer_relevancy,
            evaluate_rag_response,
            get_evaluation_info,
        )

        assert callable(calculate_context_precision)
        assert callable(evaluate_rag_response)

    def test_pdf_extraction_import(self):
        """PDF extraction module exports are importable and callable."""
        from pdf_extraction import (
            extract_text_from_pdf,
            extract_text_from_pdf_enhanced,
            get_pdf_extraction_info,
            PDFExtractionResult,
        )

        assert callable(extract_text_from_pdf)
        assert callable(extract_text_from_pdf_enhanced)

    def test_self_rag_import(self):
        """Self-RAG module exports are importable and callable."""
        from self_rag import (
            grade_document_relevance,
            filter_relevant_documents,
            self_rag_retrieve,
            get_self_rag_info,
            RetrievalDecision,
        )

        assert callable(self_rag_retrieve)
        assert RetrievalDecision is not None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# End-to-End Feature Verification
|
||||
# =============================================================================
|
||||
|
||||
class TestFeatureVerification:
    """End-to-end checks that every feature is configured and introspectable."""

    def test_all_features_have_info_endpoints(self):
        """Every advanced RAG module provides an info function returning a dict."""
        from hyde import get_hyde_info
        from hybrid_search import get_hybrid_search_info
        from rag_evaluation import get_evaluation_info
        from pdf_extraction import get_pdf_extraction_info
        from self_rag import get_self_rag_info

        info_getters = (
            get_hyde_info,
            get_hybrid_search_info,
            get_evaluation_info,
            get_pdf_extraction_info,
            get_self_rag_info,
        )

        for getter in info_getters:
            info = getter()
            assert isinstance(info, dict)
            # Each module reports some status field, even if the key differs.
            assert any(k in info for k in ["enabled", "configured_backend", "available_backends"])

    def test_environment_variables_documented(self):
        """All configuration environment variables are readable without error."""
        import os

        # Environment knobs consumed by the advanced RAG modules; reading
        # them must never raise (defaults apply when unset).
        for var in (
            "HYDE_ENABLED",
            "HYDE_LLM_BACKEND",
            "HYBRID_SEARCH_ENABLED",
            "HYBRID_DENSE_WEIGHT",
            "RAG_EVALUATION_ENABLED",
            "PDF_EXTRACTION_BACKEND",
            "SELF_RAG_ENABLED",
        ):
            os.getenv(var, "default")  # Should not raise
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Admin API Tests (RAG Documentation with HTML rendering)
|
||||
# =============================================================================
|
||||
|
||||
class TestRAGAdminAPI:
    """Tests for the RAG admin API endpoints (all xfail until implemented)."""

    @pytest.mark.xfail(reason="get_rag_documentation not yet implemented - Backlog item")
    @pytest.mark.asyncio
    async def test_rag_documentation_markdown_format(self):
        """The documentation endpoint can serve markdown."""
        from admin_api import get_rag_documentation

        response = await get_rag_documentation(format="markdown")

        assert response["format"] == "markdown"
        assert "content" in response
        assert response["status"] in ["success", "inline"]

    @pytest.mark.xfail(reason="get_rag_documentation not yet implemented - Backlog item")
    @pytest.mark.asyncio
    async def test_rag_documentation_html_format(self):
        """The documentation endpoint can serve HTML with styled tables."""
        from admin_api import get_rag_documentation

        response = await get_rag_documentation(format="html")

        assert response["format"] == "html"
        assert "content" in response

        # The rendered HTML must include proper table styling.
        rendered = response["content"]
        assert "<table" in rendered
        assert "<th>" in rendered or "<td>" in rendered
        assert "<style>" in rendered
        assert "border-collapse" in rendered

    @pytest.mark.xfail(reason="get_rag_system_info not yet implemented - Backlog item")
    @pytest.mark.asyncio
    async def test_rag_system_info_has_feature_status(self):
        """system-info must report per-feature status."""
        from admin_api import get_rag_system_info

        response = await get_rag_system_info()

        # Accept either attribute- or dict-backed response objects.
        assert "feature_status" in response.__dict__ or hasattr(response, 'feature_status')

    @pytest.mark.xfail(reason="get_rag_system_info not yet implemented - Backlog item")
    @pytest.mark.asyncio
    async def test_rag_system_info_has_privacy_notes(self):
        """system-info must carry privacy notes."""
        from admin_api import get_rag_system_info

        response = await get_rag_system_info()

        assert hasattr(response, 'privacy_notes') or "privacy_notes" in str(response)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Reranker Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestReranker:
    """Unit tests for the re-ranker module."""

    def test_reranker_config(self):
        """The reranker info dict exposes backend, model, and availability."""
        from reranker import RERANKER_BACKEND, LOCAL_RERANKER_MODEL, get_reranker_info

        cfg = get_reranker_info()
        assert "backend" in cfg
        assert "model" in cfg
        # The info dict's key is "embedding_service_available", not "available".
        assert "embedding_service_available" in cfg

    def test_reranker_model_license(self):
        """The default model must be commercially usable (Apache 2.0)."""
        from reranker import LOCAL_RERANKER_MODEL

        # BAAI/bge-reranker-v2-m3 is Apache 2.0 licensed.
        assert "bge-reranker" in LOCAL_RERANKER_MODEL
        # MS MARCO cross-encoders carry a non-commercial license — must not be used.
        assert "ms-marco" not in LOCAL_RERANKER_MODEL.lower()

    @pytest.mark.xfail(reason="rerank_results not yet implemented - async wrapper planned")
    @pytest.mark.asyncio
    async def test_rerank_results(self):
        """Reranking trims to top_k and scores each surviving result."""
        from reranker import rerank_results

        candidates = [
            {"text": "Mathematik Bewertungskriterien", "score": 0.5},
            {"text": "Rezept fuer Kuchen", "score": 0.6},
            {"text": "Erwartungshorizont Mathe Abitur", "score": 0.4},
        ]

        reranked = await rerank_results(
            query="Mathematik Abitur Bewertung",
            results=candidates,
            top_k=2,
        )

        # At most top_k entries come back, each carrying some score field.
        assert len(reranked) <= 2
        for entry in reranked:
            assert "rerank_score" in entry or "score" in entry
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# API Integration Tests (EH RAG Query with rerank param)
|
||||
# =============================================================================
|
||||
|
||||
class TestEHRAGQueryAPI:
    """Structural tests for the EH RAG query API request/response models."""

    def test_eh_rag_query_params(self):
        """The request model shape must accept the new `rerank` flag."""
        from pydantic import BaseModel
        from typing import Optional

        # Local mirror of the expected request schema.
        class RAGQueryRequest(BaseModel):
            query_text: str
            passphrase: str
            subject: Optional[str] = None
            limit: int = 5
            rerank: bool = True  # New param

        # Constructing with rerank must not raise a validation error.
        request = RAGQueryRequest(
            query_text="Test",
            passphrase="secret",
            rerank=True,
        )
        assert request.rerank is True

    def test_rag_result_has_search_info(self):
        """The response model shape must carry optional search_info metadata."""
        from pydantic import BaseModel
        from typing import Optional, List, Dict, Any

        # Local mirrors of the expected response schema.
        class RAGSource(BaseModel):
            text: str
            score: float
            reranked: Optional[bool] = None

        class RAGResult(BaseModel):
            context: str
            sources: List[RAGSource]
            query: str
            search_info: Optional[Dict[str, Any]] = None

        outcome = RAGResult(
            context="Test context",
            sources=[RAGSource(text="Test", score=0.8, reranked=True)],
            query="Test query",
            search_info={
                "rerank_applied": True,
                "hybrid_search_applied": True,
                "total_candidates": 20,
                "embedding_model": "BAAI/bge-m3",
            }
        )

        assert outcome.search_info["rerank_applied"] is True
        assert outcome.sources[0].reranked is True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Run Tests
|
||||
# =============================================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
937
klausur-service/backend/tests/test_byoeh.py
Normal file
937
klausur-service/backend/tests/test_byoeh.py
Normal file
@@ -0,0 +1,937 @@
|
||||
"""
|
||||
Unit Tests for BYOEH (Bring-Your-Own-Expectation-Horizon) Module
|
||||
|
||||
Tests cover:
|
||||
- EH upload and storage
|
||||
- Key sharing system
|
||||
- Invitation flow (Invite, Accept, Decline, Revoke)
|
||||
- Klausur linking
|
||||
- RAG query functionality
|
||||
- Audit logging
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import uuid
|
||||
import hashlib
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
# Import the main app and data structures (post-refactoring modular imports)
|
||||
import sys
|
||||
sys.path.insert(0, '..')
|
||||
from main import app
|
||||
from storage import (
|
||||
eh_db,
|
||||
eh_key_shares_db,
|
||||
eh_klausur_links_db,
|
||||
eh_audit_db,
|
||||
eh_invitations_db,
|
||||
klausuren_db,
|
||||
)
|
||||
from models.eh import (
|
||||
Erwartungshorizont,
|
||||
EHKeyShare,
|
||||
EHKlausurLink,
|
||||
EHShareInvitation,
|
||||
)
|
||||
from models.exam import Klausur
|
||||
from models.enums import KlausurModus
|
||||
|
||||
|
||||
# =============================================
|
||||
# FIXTURES
|
||||
# =============================================
|
||||
|
||||
@pytest.fixture
def client():
    """Provide a FastAPI TestClient bound to the service app."""
    test_client = TestClient(app)
    return test_client
|
||||
|
||||
|
||||
@pytest.fixture
def auth_headers():
    """Authorization header carrying a signed JWT for an admin teacher."""
    import jwt

    claims = {
        "user_id": "test-teacher-001",
        "email": "teacher@school.de",
        "role": "admin",
        "tenant_id": "school-001"
    }
    # NOTE(review): the signing key mirrors the app's development default —
    # confirm it matches the service's JWT_SECRET in test config.
    token = jwt.encode(
        claims,
        "your-super-secret-jwt-key-change-in-production",
        algorithm="HS256"
    )
    return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.fixture
def second_examiner_headers():
    """Authorization header for a second examiner (non-admin teacher)."""
    import jwt

    claims = {
        "user_id": "test-examiner-002",
        "email": "examiner2@school.de",
        "role": "teacher",  # Non-admin to test access control
        "tenant_id": "school-001"
    }
    token = jwt.encode(
        claims,
        "your-super-secret-jwt-key-change-in-production",
        algorithm="HS256"
    )
    return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
@pytest.fixture
def sample_eh():
    """Create an indexed Erwartungshorizont owned by the test teacher and register it."""
    eh_id = str(uuid.uuid4())
    now = datetime.now(timezone.utc)
    eh = Erwartungshorizont(
        id=eh_id,
        tenant_id="school-001",
        teacher_id="test-teacher-001",
        title="Deutsch LK Abitur 2025",
        subject="deutsch",
        niveau="eA",
        year=2025,
        aufgaben_nummer="Aufgabe 1",
        encryption_key_hash="abc123" + "0" * 58,  # 64 hex chars (SHA-256-sized)
        salt="def456" * 5 + "0" * 2,  # 32 chars (16 bytes hex-encoded)
        encrypted_file_path=f"/app/eh-uploads/school-001/{eh_id}/encrypted.bin",
        file_size_bytes=1024000,
        original_filename="erwartungshorizont.pdf",
        rights_confirmed=True,
        rights_confirmed_at=now,
        status="indexed",
        chunk_count=10,
        indexed_at=now,
        error_message=None,
        training_allowed=False,
        created_at=now,
        deleted_at=None,
    )
    eh_db[eh_id] = eh
    return eh
|
||||
|
||||
|
||||
@pytest.fixture
def sample_klausur():
    """Create a Klausur owned by the test teacher and register it in the store."""
    klausur = Klausur(
        id=str(uuid.uuid4()),
        title="Deutsch LK Q1",
        subject="deutsch",
        modus=KlausurModus.VORABITUR,
        class_id="class-001",
        year=2025,
        semester="Q1",
        erwartungshorizont=None,
        students=[],
        created_at=datetime.now(timezone.utc),
        teacher_id="test-teacher-001",
    )
    klausuren_db[klausur.id] = klausur
    return klausur
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def cleanup():
    """Reset every in-memory store after each test so tests stay isolated."""
    yield
    for store in (
        eh_db,
        eh_key_shares_db,
        eh_klausur_links_db,
        eh_audit_db,
        eh_invitations_db,
        klausuren_db,
    ):
        store.clear()
|
||||
|
||||
|
||||
@pytest.fixture
def sample_invitation(sample_eh):
    """Create a pending second-examiner invitation for sample_eh."""
    now = datetime.now(timezone.utc)
    invitation = EHShareInvitation(
        id=str(uuid.uuid4()),
        eh_id=sample_eh.id,
        inviter_id="test-teacher-001",
        invitee_id="",
        invitee_email="examiner2@school.de",
        role="second_examiner",
        klausur_id=None,
        message="Bitte EH fuer Zweitkorrektur nutzen",
        status="pending",
        expires_at=now + timedelta(days=14),
        created_at=now,
        accepted_at=None,
        declined_at=None,
    )
    eh_invitations_db[invitation.id] = invitation
    return invitation
|
||||
|
||||
|
||||
# =============================================
|
||||
# EH CRUD TESTS
|
||||
# =============================================
|
||||
|
||||
class TestEHList:
    """Tests for GET /api/v1/eh"""

    def _list(self, client, headers, query=""):
        """Issue a list request with an optional query string; assert 200 and return JSON."""
        response = client.get(f"/api/v1/eh{query}", headers=headers)
        assert response.status_code == 200
        return response.json()

    def test_list_empty(self, client, auth_headers):
        """List returns empty when no EH exist."""
        assert self._list(client, auth_headers) == []

    def test_list_own_eh(self, client, auth_headers, sample_eh):
        """List returns only user's own EH."""
        data = self._list(client, auth_headers)
        assert len(data) == 1
        assert data[0]["id"] == sample_eh.id
        assert data[0]["title"] == sample_eh.title

    def test_list_filter_by_subject(self, client, auth_headers, sample_eh):
        """List can filter by subject."""
        assert len(self._list(client, auth_headers, "?subject=deutsch")) == 1
        assert len(self._list(client, auth_headers, "?subject=englisch")) == 0

    def test_list_filter_by_year(self, client, auth_headers, sample_eh):
        """List can filter by year."""
        assert len(self._list(client, auth_headers, "?year=2025")) == 1
        assert len(self._list(client, auth_headers, "?year=2024")) == 0
|
||||
|
||||
|
||||
class TestEHGet:
    """Tests for GET /api/v1/eh/{id}"""

    def test_get_existing_eh(self, client, auth_headers, sample_eh):
        """Get returns EH details."""
        resp = client.get(f"/api/v1/eh/{sample_eh.id}", headers=auth_headers)
        assert resp.status_code == 200
        body = resp.json()
        # Core identifying fields must round-trip from the fixture.
        for field in ("id", "title", "subject"):
            assert body[field] == getattr(sample_eh, field)

    def test_get_nonexistent_eh(self, client, auth_headers):
        """Get returns 404 for non-existent EH."""
        resp = client.get(f"/api/v1/eh/{uuid.uuid4()}", headers=auth_headers)
        assert resp.status_code == 404
|
||||
|
||||
|
||||
class TestEHDelete:
    """Tests for DELETE /api/v1/eh/{id}"""

    def test_delete_own_eh(self, client, auth_headers, sample_eh):
        """Owner can delete their EH."""
        resp = client.delete(f"/api/v1/eh/{sample_eh.id}", headers=auth_headers)
        assert resp.status_code == 200
        assert resp.json()["status"] == "deleted"

        # Deletion is soft: the record stays but deleted_at gets stamped.
        assert eh_db[sample_eh.id].deleted_at is not None

    def test_delete_others_eh(self, client, second_examiner_headers, sample_eh):
        """Non-owner cannot delete EH."""
        resp = client.delete(f"/api/v1/eh/{sample_eh.id}", headers=second_examiner_headers)
        assert resp.status_code == 403

    def test_delete_nonexistent_eh(self, client, auth_headers):
        """Delete returns 404 for non-existent EH."""
        resp = client.delete(f"/api/v1/eh/{uuid.uuid4()}", headers=auth_headers)
        assert resp.status_code == 404
|
||||
|
||||
|
||||
# =============================================
|
||||
# KEY SHARING TESTS
|
||||
# =============================================
|
||||
|
||||
class TestEHSharing:
    """Tests for EH key sharing system."""

    @staticmethod
    def _store_share(eh_id, share_id=None):
        """Register an active second-examiner key share for *eh_id* and return it."""
        share = EHKeyShare(
            id=share_id if share_id is not None else str(uuid.uuid4()),
            eh_id=eh_id,
            user_id="test-examiner-002",
            encrypted_passphrase="encrypted",
            passphrase_hint="hint",
            granted_by="test-teacher-001",
            granted_at=datetime.now(timezone.utc),
            role="second_examiner",
            klausur_id=None,
            active=True,
        )
        eh_key_shares_db[eh_id] = [share]
        return share

    def test_share_eh_with_examiner(self, client, auth_headers, sample_eh):
        """Owner can share EH with another examiner."""
        payload = {
            "user_id": "test-examiner-002",
            "role": "second_examiner",
            "encrypted_passphrase": "encrypted-secret-123",
            "passphrase_hint": "Das uebliche Passwort",
        }
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/share", headers=auth_headers, json=payload
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "shared"
        assert body["shared_with"] == "test-examiner-002"
        assert body["role"] == "second_examiner"

    def test_share_invalid_role(self, client, auth_headers, sample_eh):
        """Sharing with invalid role fails."""
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/share",
            headers=auth_headers,
            json={
                "user_id": "test-examiner-002",
                "role": "invalid_role",
                "encrypted_passphrase": "encrypted-secret-123",
            },
        )
        assert resp.status_code == 400

    def test_share_others_eh_fails(self, client, second_examiner_headers, sample_eh):
        """Non-owner cannot share EH."""
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/share",
            headers=second_examiner_headers,
            json={
                "user_id": "test-examiner-003",
                "role": "third_examiner",
                "encrypted_passphrase": "encrypted-secret-123",
            },
        )
        assert resp.status_code == 403

    def test_list_shares(self, client, auth_headers, sample_eh):
        """Owner can list shares."""
        self._store_share(sample_eh.id)

        resp = client.get(f"/api/v1/eh/{sample_eh.id}/shares", headers=auth_headers)
        assert resp.status_code == 200
        body = resp.json()
        assert len(body) == 1
        assert body[0]["user_id"] == "test-examiner-002"

    def test_revoke_share(self, client, auth_headers, sample_eh):
        """Owner can revoke a share."""
        share_id = str(uuid.uuid4())
        self._store_share(sample_eh.id, share_id=share_id)

        resp = client.delete(
            f"/api/v1/eh/{sample_eh.id}/shares/{share_id}",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        assert resp.json()["status"] == "revoked"

        # Revocation is soft: the share object remains but is flagged inactive.
        assert not eh_key_shares_db[sample_eh.id][0].active

    def test_get_shared_with_me(self, client, second_examiner_headers, sample_eh):
        """User can see EH shared with them."""
        self._store_share(sample_eh.id)

        resp = client.get("/api/v1/eh/shared-with-me", headers=second_examiner_headers)
        assert resp.status_code == 200
        body = resp.json()
        assert len(body) == 1
        assert body[0]["eh"]["id"] == sample_eh.id
|
||||
|
||||
|
||||
# =============================================
|
||||
# KLAUSUR LINKING TESTS
|
||||
# =============================================
|
||||
|
||||
class TestEHKlausurLinking:
    """Tests for EH-Klausur linking."""

    @staticmethod
    def _store_link(eh_id, klausur_id):
        """Register an EH->Klausur link created by the owning teacher."""
        link = EHKlausurLink(
            id=str(uuid.uuid4()),
            eh_id=eh_id,
            klausur_id=klausur_id,
            linked_by="test-teacher-001",
            linked_at=datetime.now(timezone.utc),
        )
        eh_klausur_links_db[klausur_id] = [link]
        return link

    def test_link_eh_to_klausur(self, client, auth_headers, sample_eh, sample_klausur):
        """Owner can link EH to Klausur."""
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/link-klausur",
            headers=auth_headers,
            json={"klausur_id": sample_klausur.id},
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "linked"
        assert body["eh_id"] == sample_eh.id
        assert body["klausur_id"] == sample_klausur.id

    def test_get_linked_eh(self, client, auth_headers, sample_eh, sample_klausur):
        """Get linked EH for a Klausur."""
        self._store_link(sample_eh.id, sample_klausur.id)

        resp = client.get(
            f"/api/v1/klausuren/{sample_klausur.id}/linked-eh",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        body = resp.json()
        assert len(body) == 1
        assert body[0]["eh"]["id"] == sample_eh.id
        assert body[0]["is_owner"] is True

    def test_unlink_eh_from_klausur(self, client, auth_headers, sample_eh, sample_klausur):
        """Owner can unlink EH from Klausur."""
        self._store_link(sample_eh.id, sample_klausur.id)

        resp = client.delete(
            f"/api/v1/eh/{sample_eh.id}/link-klausur/{sample_klausur.id}",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        assert resp.json()["status"] == "unlinked"
|
||||
|
||||
|
||||
# =============================================
|
||||
# AUDIT LOG TESTS
|
||||
# =============================================
|
||||
|
||||
class TestAuditLog:
    """Tests for audit logging."""

    @staticmethod
    def _share_eh(client, headers, eh_id):
        """Trigger a share action for *eh_id* (which should write an audit entry)."""
        return client.post(
            f"/api/v1/eh/{eh_id}/share",
            headers=headers,
            json={
                "user_id": "test-examiner-002",
                "role": "second_examiner",
                "encrypted_passphrase": "encrypted-secret-123",
            },
        )

    def test_audit_log_on_share(self, client, auth_headers, sample_eh):
        """Sharing creates audit log entry."""
        before = len(eh_audit_db)

        self._share_eh(client, auth_headers, sample_eh.id)

        assert len(eh_audit_db) > before
        newest = eh_audit_db[-1]
        assert newest.action == "share"
        assert newest.eh_id == sample_eh.id

    def test_get_audit_log(self, client, auth_headers, sample_eh):
        """Can retrieve audit log."""
        # Produce at least one audit entry by performing a share.
        self._share_eh(client, auth_headers, sample_eh.id)

        resp = client.get(
            f"/api/v1/eh/audit-log?eh_id={sample_eh.id}",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        assert len(resp.json()) > 0
|
||||
|
||||
|
||||
# =============================================
|
||||
# RIGHTS TEXT TESTS
|
||||
# =============================================
|
||||
|
||||
class TestRightsText:
    """Tests for rights confirmation text."""

    def test_get_rights_text(self, client, auth_headers):
        """Can retrieve rights confirmation text."""
        resp = client.get("/api/v1/eh/rights-text", headers=auth_headers)
        assert resp.status_code == 200
        body = resp.json()
        # The endpoint must deliver a versioned text mentioning copyright.
        assert "text" in body
        assert "version" in body
        assert "Urheberrecht" in body["text"]
|
||||
|
||||
|
||||
# =============================================
|
||||
# ENCRYPTION SERVICE TESTS
|
||||
# =============================================
|
||||
|
||||
class TestEncryptionUtils:
    """Tests for encryption utilities (eh_pipeline.py)."""

    def test_hash_key(self):
        """Key hashing is deterministic and produces a SHA-256-sized hex digest."""
        from eh_pipeline import hash_key
        import os

        passphrase = "test-secret-passphrase"
        salt_hex = os.urandom(16).hex()
        hash1 = hash_key(passphrase, salt_hex)
        hash2 = hash_key(passphrase, salt_hex)

        assert hash1 == hash2
        assert len(hash1) == 64  # SHA-256 hex

    def test_verify_key_hash(self):
        """Key hash verification accepts the correct passphrase and rejects a wrong one."""
        from eh_pipeline import hash_key, verify_key_hash
        import os

        passphrase = "test-secret-passphrase"
        salt_hex = os.urandom(16).hex()
        key_hash = hash_key(passphrase, salt_hex)

        assert verify_key_hash(passphrase, salt_hex, key_hash) is True
        assert verify_key_hash("wrong-passphrase", salt_hex, key_hash) is False

    def test_chunk_text(self):
        """Text chunking produces correct overlap.

        Uses position-dependent text: with a constant input such as
        "A" * 2000, any two 200-char slices compare equal, so the overlap
        assertion would pass even if the chunker returned arbitrary slices.
        """
        from eh_pipeline import chunk_text

        # 2000 characters where each position is distinguishable (mod 10).
        text = "".join(str(i % 10) for i in range(2000))
        chunks = chunk_text(text, chunk_size=1000, overlap=200)

        assert len(chunks) >= 2
        # The tail of chunk 0 must be repeated verbatim at the head of chunk 1.
        assert chunks[0][-200:] == chunks[1][:200]

    def test_encrypt_decrypt_text(self):
        """Text encryption and decryption round-trip."""
        from eh_pipeline import encrypt_text, decrypt_text

        plaintext = "Dies ist ein geheimer Text."
        passphrase = "geheim123"
        salt = "a" * 32  # 32 hex chars = 16 bytes

        encrypted = encrypt_text(plaintext, passphrase, salt)
        decrypted = decrypt_text(encrypted, passphrase, salt)

        assert decrypted == plaintext
|
||||
|
||||
|
||||
# =============================================
|
||||
# INVITATION FLOW TESTS
|
||||
# =============================================
|
||||
|
||||
class TestEHInvitationFlow:
    """Tests for the Invite/Accept/Decline/Revoke workflow."""

    @staticmethod
    def _store_invitation(eh_id, **overrides):
        """Insert an EHShareInvitation built from pending defaults plus *overrides*."""
        now = datetime.now(timezone.utc)
        fields = dict(
            id=str(uuid.uuid4()),
            eh_id=eh_id,
            inviter_id="test-teacher-001",
            invitee_id="",
            invitee_email="examiner2@school.de",
            role="second_examiner",
            klausur_id=None,
            message=None,
            status="pending",
            expires_at=now + timedelta(days=14),
            created_at=now,
            accepted_at=None,
            declined_at=None,
        )
        fields.update(overrides)
        invitation = EHShareInvitation(**fields)
        eh_invitations_db[fields["id"]] = invitation
        return invitation

    def test_invite_to_eh(self, client, auth_headers, sample_eh):
        """Owner can send invitation to share EH."""
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/invite",
            headers=auth_headers,
            json={
                "invitee_email": "zweitkorrektor@school.de",
                "role": "second_examiner",
                "message": "Bitte EH fuer Zweitkorrektur nutzen",
                "expires_in_days": 14,
            },
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "invited"
        assert body["invitee_email"] == "zweitkorrektor@school.de"
        assert body["role"] == "second_examiner"
        assert "invitation_id" in body
        assert "expires_at" in body

    def test_invite_invalid_role(self, client, auth_headers, sample_eh):
        """Invitation with invalid role fails."""
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/invite",
            headers=auth_headers,
            json={
                "invitee_email": "zweitkorrektor@school.de",
                "role": "invalid_role",
            },
        )
        assert resp.status_code == 400

    def test_invite_by_non_owner_fails(self, client, second_examiner_headers, sample_eh):
        """Non-owner cannot send invitation."""
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/invite",
            headers=second_examiner_headers,
            json={
                "invitee_email": "drittkorrektor@school.de",
                "role": "third_examiner",
            },
        )
        assert resp.status_code == 403

    def test_duplicate_pending_invitation_fails(self, client, auth_headers, sample_eh, sample_invitation):
        """Cannot send duplicate pending invitation to same user."""
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/invite",
            headers=auth_headers,
            json={
                "invitee_email": "examiner2@school.de",  # Same email as sample_invitation
                "role": "second_examiner",
            },
        )
        assert resp.status_code == 409  # Conflict

    def test_list_pending_invitations(self, client, second_examiner_headers, sample_eh, sample_invitation):
        """User can see pending invitations addressed to them."""
        resp = client.get("/api/v1/eh/invitations/pending", headers=second_examiner_headers)
        assert resp.status_code == 200
        body = resp.json()
        assert len(body) == 1
        assert body[0]["invitation"]["id"] == sample_invitation.id
        assert body[0]["eh"]["title"] == sample_eh.title

    def test_list_sent_invitations(self, client, auth_headers, sample_eh, sample_invitation):
        """Inviter can see sent invitations."""
        resp = client.get("/api/v1/eh/invitations/sent", headers=auth_headers)
        assert resp.status_code == 200
        body = resp.json()
        assert len(body) == 1
        assert body[0]["invitation"]["id"] == sample_invitation.id

    def test_accept_invitation(self, client, second_examiner_headers, sample_eh, sample_invitation):
        """Invitee can accept invitation and get access."""
        resp = client.post(
            f"/api/v1/eh/invitations/{sample_invitation.id}/accept",
            headers=second_examiner_headers,
            json={"encrypted_passphrase": "encrypted-secret-key-for-zk"},
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "accepted"
        assert body["eh_id"] == sample_eh.id
        assert "share_id" in body

        # Invitation flips to accepted ...
        assert eh_invitations_db[sample_invitation.id].status == "accepted"
        # ... and a key share gets materialized for the invitee.
        assert sample_eh.id in eh_key_shares_db
        assert len(eh_key_shares_db[sample_eh.id]) == 1

    def test_accept_invitation_wrong_user(self, client, auth_headers, sample_eh, sample_invitation):
        """Only invitee can accept invitation."""
        # auth_headers belongs to the inviting teacher, not the invitee.
        resp = client.post(
            f"/api/v1/eh/invitations/{sample_invitation.id}/accept",
            headers=auth_headers,
            json={"encrypted_passphrase": "encrypted-secret"},
        )
        assert resp.status_code == 403

    def test_accept_expired_invitation(self, client, second_examiner_headers, sample_eh):
        """Cannot accept expired invitation."""
        now = datetime.now(timezone.utc)
        expired = self._store_invitation(
            sample_eh.id,
            expires_at=now - timedelta(days=1),  # deadline already passed
            created_at=now - timedelta(days=15),
        )

        resp = client.post(
            f"/api/v1/eh/invitations/{expired.id}/accept",
            headers=second_examiner_headers,
            json={"encrypted_passphrase": "encrypted-secret"},
        )
        assert resp.status_code == 400
        assert "expired" in resp.json()["detail"].lower()

    def test_decline_invitation(self, client, second_examiner_headers, sample_invitation):
        """Invitee can decline invitation."""
        resp = client.post(
            f"/api/v1/eh/invitations/{sample_invitation.id}/decline",
            headers=second_examiner_headers,
        )
        assert resp.status_code == 200
        assert resp.json()["status"] == "declined"

        # Persisted status must reflect the decline.
        assert eh_invitations_db[sample_invitation.id].status == "declined"

    def test_decline_invitation_wrong_user(self, client, auth_headers, sample_invitation):
        """Only invitee can decline invitation."""
        resp = client.post(
            f"/api/v1/eh/invitations/{sample_invitation.id}/decline",
            headers=auth_headers,
        )
        assert resp.status_code == 403

    def test_revoke_invitation(self, client, auth_headers, sample_invitation):
        """Inviter can revoke pending invitation."""
        resp = client.delete(
            f"/api/v1/eh/invitations/{sample_invitation.id}",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        assert resp.json()["status"] == "revoked"

        # Persisted status must reflect the revocation.
        assert eh_invitations_db[sample_invitation.id].status == "revoked"

    def test_revoke_invitation_wrong_user(self, client, second_examiner_headers, sample_invitation):
        """Only inviter can revoke invitation."""
        resp = client.delete(
            f"/api/v1/eh/invitations/{sample_invitation.id}",
            headers=second_examiner_headers,
        )
        assert resp.status_code == 403

    def test_revoke_non_pending_invitation(self, client, auth_headers, sample_eh):
        """Cannot revoke already accepted invitation."""
        accepted = self._store_invitation(
            sample_eh.id,
            invitee_id="test-examiner-002",
            status="accepted",
            accepted_at=datetime.now(timezone.utc),
        )

        resp = client.delete(
            f"/api/v1/eh/invitations/{accepted.id}",
            headers=auth_headers,
        )
        assert resp.status_code == 400

    def test_get_access_chain(self, client, auth_headers, sample_eh, sample_invitation):
        """Owner can see complete access chain."""
        # Add an already-granted key share next to the pending invitation.
        share = EHKeyShare(
            id=str(uuid.uuid4()),
            eh_id=sample_eh.id,
            user_id="test-examiner-003",
            encrypted_passphrase="encrypted",
            passphrase_hint="",
            granted_by="test-teacher-001",
            granted_at=datetime.now(timezone.utc),
            role="third_examiner",
            klausur_id=None,
            active=True,
        )
        eh_key_shares_db[sample_eh.id] = [share]

        resp = client.get(f"/api/v1/eh/{sample_eh.id}/access-chain", headers=auth_headers)
        assert resp.status_code == 200
        body = resp.json()

        assert body["eh_id"] == sample_eh.id
        assert body["owner"]["user_id"] == "test-teacher-001"
        assert len(body["active_shares"]) == 1
        assert len(body["pending_invitations"]) == 1  # sample_invitation
|
||||
|
||||
|
||||
class TestInvitationWorkflow:
    """Integration tests for complete invitation workflow."""

    def test_complete_invite_accept_workflow(
        self, client, auth_headers, second_examiner_headers, sample_eh, sample_klausur
    ):
        """Test complete workflow: invite -> accept -> access."""
        # Step 1: owner (EK) invites the second examiner (ZK).
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/invite",
            headers=auth_headers,
            json={
                "invitee_email": "examiner2@school.de",
                "role": "second_examiner",
                "klausur_id": sample_klausur.id,
                "message": "Bitte fuer Zweitkorrektur nutzen",
            },
        )
        assert resp.status_code == 200
        invitation_id = resp.json()["invitation_id"]

        # Step 2: ZK sees the pending invitation.
        resp = client.get("/api/v1/eh/invitations/pending", headers=second_examiner_headers)
        assert resp.status_code == 200
        pending = resp.json()
        assert len(pending) == 1
        assert pending[0]["invitation"]["id"] == invitation_id

        # Step 3: ZK accepts the invitation.
        resp = client.post(
            f"/api/v1/eh/invitations/{invitation_id}/accept",
            headers=second_examiner_headers,
            json={"encrypted_passphrase": "encrypted-key-for-zk"},
        )
        assert resp.status_code == 200

        # Step 4: the EH now shows up in ZK's shared list.
        resp = client.get("/api/v1/eh/shared-with-me", headers=second_examiner_headers)
        assert resp.status_code == 200
        shared = resp.json()
        assert len(shared) == 1
        assert shared[0]["eh"]["id"] == sample_eh.id

        # Step 5: EK sees the invitation marked as accepted.
        resp = client.get("/api/v1/eh/invitations/sent", headers=auth_headers)
        assert resp.status_code == 200
        sent = resp.json()
        assert len(sent) == 1
        assert sent[0]["invitation"]["status"] == "accepted"

    def test_invite_decline_reinvite_workflow(
        self, client, auth_headers, second_examiner_headers, sample_eh
    ):
        """Test workflow: invite -> decline -> re-invite."""
        # Step 1: owner invites the second examiner.
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/invite",
            headers=auth_headers,
            json={
                "invitee_email": "examiner2@school.de",
                "role": "second_examiner",
            },
        )
        assert resp.status_code == 200
        invitation_id = resp.json()["invitation_id"]

        # Step 2: ZK declines.
        resp = client.post(
            f"/api/v1/eh/invitations/{invitation_id}/decline",
            headers=second_examiner_headers,
        )
        assert resp.status_code == 200

        # Step 3: a declined invitation does not block a fresh one.
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/invite",
            headers=auth_headers,
            json={
                "invitee_email": "examiner2@school.de",
                "role": "second_examiner",
                "message": "Zweiter Versuch",
            },
        )
        assert resp.status_code == 200  # New invitation allowed
|
||||
|
||||
|
||||
# =============================================
|
||||
# INTEGRATION TESTS
|
||||
# =============================================
|
||||
|
||||
class TestEHWorkflow:
    """Integration tests for complete EH workflow."""

    def test_complete_sharing_workflow(
        self, client, auth_headers, second_examiner_headers, sample_eh, sample_klausur
    ):
        """Test complete workflow: upload -> link -> share -> access."""
        # Step 1: link the EH to the Klausur.
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/link-klausur",
            headers=auth_headers,
            json={"klausur_id": sample_klausur.id},
        )
        assert resp.status_code == 200

        # Step 2: share the EH with the second examiner.
        resp = client.post(
            f"/api/v1/eh/{sample_eh.id}/share",
            headers=auth_headers,
            json={
                "user_id": "test-examiner-002",
                "role": "second_examiner",
                "encrypted_passphrase": "encrypted-secret",
                "klausur_id": sample_klausur.id,
            },
        )
        assert resp.status_code == 200

        # Step 3: the second examiner can see the shared EH.
        resp = client.get("/api/v1/eh/shared-with-me", headers=second_examiner_headers)
        assert resp.status_code == 200
        shared = resp.json()
        assert len(shared) == 1
        assert shared[0]["eh"]["id"] == sample_eh.id

        # Step 4: the second examiner can see the linked EH for the Klausur.
        resp = client.get(
            f"/api/v1/klausuren/{sample_klausur.id}/linked-eh",
            headers=second_examiner_headers,
        )
        assert resp.status_code == 200
        linked = resp.json()
        assert len(linked) == 1
        assert linked[0]["is_owner"] is False
        assert linked[0]["share"] is not None
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly (outside a pytest invocation).
    pytest.main([__file__, "-v"])
|
||||
623
klausur-service/backend/tests/test_legal_templates.py
Normal file
623
klausur-service/backend/tests/test_legal_templates.py
Normal file
@@ -0,0 +1,623 @@
|
||||
"""
|
||||
Tests for Legal Templates RAG System.
|
||||
|
||||
Tests template_sources.py, github_crawler.py, legal_templates_ingestion.py,
|
||||
and the admin API endpoints for legal templates.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Template Sources Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestLicenseType:
    """Tests for LicenseType enum."""

    def test_license_types_exist(self):
        """Test that all expected license types are defined."""
        from template_sources import LicenseType

        # Each enum member must carry its expected wire value.
        expected = {
            LicenseType.PUBLIC_DOMAIN: "public_domain",
            LicenseType.CC0: "cc0",
            LicenseType.UNLICENSE: "unlicense",
            LicenseType.MIT: "mit",
            LicenseType.CC_BY_4: "cc_by_4",
            LicenseType.REUSE_NOTICE: "reuse_notice",
        }
        for member, value in expected.items():
            assert member.value == value
|
||||
|
||||
class TestLicenseInfo:
    """Tests for LicenseInfo dataclass."""

    def test_license_info_creation(self):
        """A LicenseInfo keeps its id/flags and defaults to permissive use."""
        from template_sources import LicenseInfo, LicenseType

        cc0 = LicenseInfo(
            id=LicenseType.CC0,
            name="CC0 1.0 Universal",
            url="https://creativecommons.org/publicdomain/zero/1.0/",
            attribution_required=False,
        )

        assert cc0.id == LicenseType.CC0
        assert cc0.attribution_required is False
        # Training and output usage default to allowed.
        assert cc0.training_allowed is True
        assert cc0.output_allowed is True

    def test_get_attribution_text_no_attribution(self):
        """No attribution required -> empty attribution string."""
        from template_sources import LicenseInfo, LicenseType

        cc0 = LicenseInfo(
            id=LicenseType.CC0,
            name="CC0",
            url="https://example.com",
            attribution_required=False,
        )

        assert cc0.get_attribution_text("Test Source", "https://test.com") == ""

    def test_get_attribution_text_with_template(self):
        """The attribution template interpolates source name and URL."""
        from template_sources import LicenseInfo, LicenseType

        mit = LicenseInfo(
            id=LicenseType.MIT,
            name="MIT License",
            url="https://opensource.org/licenses/MIT",
            attribution_required=True,
            attribution_template="Based on [{source_name}]({source_url}) - MIT License",
        )

        text = mit.get_attribution_text("Test Source", "https://test.com")
        assert "Test Source" in text
        assert "https://test.com" in text
|
||||
|
||||
|
||||
class TestSourceConfig:
    """Tests for SourceConfig dataclass."""

    def test_source_config_creation(self):
        """A SourceConfig holds its fields and defaults to enabled."""
        from template_sources import SourceConfig, LicenseType

        cfg = SourceConfig(
            name="test-source",
            license_type=LicenseType.CC0,
            template_types=["privacy_policy", "terms_of_service"],
            languages=["de", "en"],
            jurisdiction="DE",
            description="Test description",
            repo_url="https://github.com/test/repo",
        )

        assert cfg.name == "test-source"
        assert cfg.license_type == LicenseType.CC0
        assert "privacy_policy" in cfg.template_types
        assert cfg.enabled is True

    def test_source_config_license_info(self):
        """license_info resolves the full LicenseInfo for the config's license."""
        from template_sources import SourceConfig, LicenseType, LICENSES

        cfg = SourceConfig(
            name="test-source",
            license_type=LicenseType.MIT,
            template_types=["privacy_policy"],
            languages=["en"],
            jurisdiction="US",
            description="Test",
        )

        resolved = cfg.license_info
        assert resolved.id == LicenseType.MIT
        assert resolved.attribution_required is True
|
||||
|
||||
|
||||
class TestTemplateSources:
    """Tests for TEMPLATE_SOURCES list."""

    def test_template_sources_not_empty(self):
        """At least one template source is configured."""
        from template_sources import TEMPLATE_SOURCES

        assert len(TEMPLATE_SOURCES) > 0

    def test_github_site_policy_exists(self):
        """The github-site-policy source is configured with its repo URL."""
        from template_sources import TEMPLATE_SOURCES

        match = next((s for s in TEMPLATE_SOURCES if s.name == "github-site-policy"), None)
        assert match is not None
        assert match.repo_url == "https://github.com/github/site-policy"

    def test_enabled_sources(self):
        """get_enabled_sources returns only enabled sources."""
        from template_sources import get_enabled_sources

        for src in get_enabled_sources():
            assert src.enabled

    def test_sources_by_priority(self):
        """get_sources_by_priority caps the priority of returned sources."""
        from template_sources import get_sources_by_priority

        # Priority 1 sources only.
        for src in get_sources_by_priority(1):
            assert src.priority == 1
        # Priority 1-2 sources.
        for src in get_sources_by_priority(2):
            assert src.priority <= 2

    def test_sources_by_license(self):
        """get_sources_by_license filters by the given license type."""
        from template_sources import get_sources_by_license, LicenseType

        for src in get_sources_by_license(LicenseType.CC0):
            assert src.license_type == LicenseType.CC0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# GitHub Crawler Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestMarkdownParser:
    """Tests for MarkdownParser class."""

    def test_parse_simple_markdown(self):
        """Title, file type and body survive a simple markdown parse."""
        from github_crawler import MarkdownParser

        source = """# Test Title

This is some content.

## Section 1

More content here.
"""
        doc = MarkdownParser.parse(source, "test.md")

        assert doc.title == "Test Title"
        assert doc.file_type == "markdown"
        assert "content" in doc.text

    def test_extract_title_from_heading(self):
        """An h1 heading wins over the fallback filename."""
        from github_crawler import MarkdownParser

        assert MarkdownParser._extract_title("# My Document\n\nContent", "fallback.md") == "My Document"

    def test_extract_title_fallback(self):
        """Without a heading, the filename is titleized."""
        from github_crawler import MarkdownParser

        assert MarkdownParser._extract_title("No heading here", "my-document.md") == "My Document"

    def test_detect_german_language(self):
        """German text is detected as 'de'."""
        from github_crawler import MarkdownParser

        sample = "Dies ist eine Datenschutzerklaerung fuer die Verarbeitung personenbezogener Daten."
        assert MarkdownParser._detect_language(sample) == "de"

    def test_detect_english_language(self):
        """English text is detected as 'en'."""
        from github_crawler import MarkdownParser

        sample = "This is a privacy policy for processing personal data in our application."
        assert MarkdownParser._detect_language(sample) == "en"

    def test_find_placeholders(self):
        """Bracket, brace and dunder placeholder styles are all found."""
        from github_crawler import MarkdownParser

        found = MarkdownParser._find_placeholders(
            "Company: [COMPANY_NAME], Contact: {email}, Address: __ADDRESS__"
        )
        for placeholder in ("[COMPANY_NAME]", "{email}", "__ADDRESS__"):
            assert placeholder in found
|
||||
|
||||
|
||||
class TestHTMLParser:
    """Tests for HTMLParser class."""

    def test_parse_simple_html(self):
        """Title and body text are extracted from an HTML page."""
        from github_crawler import HTMLParser

        page = """<!DOCTYPE html>
<html>
<head><title>Test Page</title></head>
<body>
<h1>Welcome</h1>
<p>This is content.</p>
</body>
</html>"""
        doc = HTMLParser.parse(page, "test.html")

        assert doc.title == "Test Page"
        assert doc.file_type == "html"
        for fragment in ("Welcome", "content"):
            assert fragment in doc.text

    def test_html_to_text_removes_scripts(self):
        """Script bodies are stripped while visible text is kept."""
        from github_crawler import HTMLParser

        markup = "<p>Text</p><script>alert('bad');</script><p>More</p>"
        plain = HTMLParser._html_to_text(markup)

        assert "alert" not in plain
        assert "Text" in plain
        assert "More" in plain
|
||||
|
||||
|
||||
class TestJSONParser:
    """Tests for JSONParser class."""

    def test_parse_simple_json(self):
        """A flat JSON object becomes a single document."""
        from github_crawler import JSONParser

        payload = {
            "title": "Privacy Policy",
            "text": "This is the privacy policy content.",
            "language": "en",
        }
        docs = JSONParser.parse(json.dumps(payload), "policy.json")

        assert len(docs) == 1
        assert docs[0].title == "Privacy Policy"
        assert "privacy policy content" in docs[0].text

    def test_parse_nested_json(self):
        """Nested JSON sections are extracted as separate documents."""
        from github_crawler import JSONParser

        payload = {
            "sections": {
                "intro": {"title": "Introduction", "text": "Welcome text"},
                "data": {"title": "Data Collection", "text": "Collection info"},
            }
        }
        docs = JSONParser.parse(json.dumps(payload), "nested.json")

        # Both nested sections should yield documents.
        assert len(docs) >= 2
|
||||
|
||||
|
||||
class TestExtractedDocument:
    """Tests for ExtractedDocument dataclass."""

    def test_extracted_document_hash(self):
        """A non-empty SHA256 hex source hash is derived automatically."""
        from github_crawler import ExtractedDocument

        doc = ExtractedDocument(
            text="Some content",
            title="Test",
            file_path="test.md",
            file_type="markdown",
            source_url="https://example.com",
        )

        digest = doc.source_hash
        assert digest != ""
        assert len(digest) == 64  # SHA256 hex digest length
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Legal Templates Ingestion Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestLegalTemplatesIngestion:
    """Tests for LegalTemplatesIngestion class."""

    @pytest.fixture
    def mock_qdrant(self):
        """Mock Qdrant client.

        NOTE(review): no test in this class currently requests this
        fixture; kept for parity with the original suite.
        """
        with patch('legal_templates_ingestion.QdrantClient') as mock:
            fake_client = MagicMock()
            fake_client.get_collections.return_value.collections = []
            mock.return_value = fake_client
            yield fake_client

    @pytest.fixture
    def mock_http_client(self):
        """Mock HTTP client for embeddings.

        NOTE(review): no test in this class currently requests this
        fixture; kept for parity with the original suite.
        """
        with patch('legal_templates_ingestion.httpx.AsyncClient') as mock:
            fake_http = AsyncMock()
            mock.return_value = fake_http
            yield fake_http

    def test_chunk_text_short(self):
        """A text shorter than the chunk size yields a single chunk."""
        from legal_templates_ingestion import LegalTemplatesIngestion

        with patch('legal_templates_ingestion.QdrantClient'):
            chunks = LegalTemplatesIngestion()._chunk_text("Short text", chunk_size=1000)

        assert len(chunks) == 1
        assert chunks[0] == "Short text"

    def test_chunk_text_long(self):
        """A long text is split into multiple, size-bounded chunks."""
        from legal_templates_ingestion import LegalTemplatesIngestion

        # Build a text well past the chunk size.
        long_text = "This is a sentence. " * 100
        with patch('legal_templates_ingestion.QdrantClient'):
            chunks = LegalTemplatesIngestion()._chunk_text(long_text, chunk_size=200, overlap=50)

        assert len(chunks) > 1
        # Each chunk stays near the requested size (small buffer allowed).
        for piece in chunks:
            assert len(piece) <= 250

    def test_split_sentences(self):
        """German sentence splitting yields one entry per sentence."""
        from legal_templates_ingestion import LegalTemplatesIngestion

        sample = "Dies ist Satz eins. Dies ist Satz zwei. Und Satz drei."
        with patch('legal_templates_ingestion.QdrantClient'):
            sentences = LegalTemplatesIngestion()._split_sentences(sample)

        assert len(sentences) == 3

    def test_split_sentences_preserves_abbreviations(self):
        """Abbreviations such as 'z.B.' must not split a sentence."""
        from legal_templates_ingestion import LegalTemplatesIngestion

        sample = "Das ist z.B. ein Beispiel. Und noch ein Satz."
        with patch('legal_templates_ingestion.QdrantClient'):
            sentences = LegalTemplatesIngestion()._split_sentences(sample)

        assert len(sentences) == 2
        assert "z.b." in sentences[0].lower() or "z.B." in sentences[0]

    def test_infer_template_type_privacy(self):
        """A Datenschutz document is inferred as privacy_policy."""
        from legal_templates_ingestion import LegalTemplatesIngestion
        from github_crawler import ExtractedDocument
        from template_sources import SourceConfig, LicenseType

        privacy_doc = ExtractedDocument(
            text="Diese Datenschutzerklaerung informiert Sie ueber die Verarbeitung personenbezogener Daten.",
            title="Datenschutz",
            file_path="privacy.md",
            file_type="markdown",
            source_url="https://example.com",
        )
        de_source = SourceConfig(
            name="test",
            license_type=LicenseType.CC0,
            template_types=["privacy_policy"],
            languages=["de"],
            jurisdiction="DE",
            description="Test",
        )

        with patch('legal_templates_ingestion.QdrantClient'):
            inferred = LegalTemplatesIngestion()._infer_template_type(privacy_doc, de_source)

        assert inferred == "privacy_policy"

    def test_infer_clause_category(self):
        """Clause category inference recognizes liability and privacy text."""
        from legal_templates_ingestion import LegalTemplatesIngestion

        cases = [
            # Liability clause.
            ("Die Haftung des Anbieters ist auf grobe Fahrlässigkeit beschränkt.", "haftung"),
            # Privacy clause.
            ("Wir verarbeiten personenbezogene Daten gemäß der DSGVO.", "datenschutz"),
        ]
        with patch('legal_templates_ingestion.QdrantClient'):
            ingestion = LegalTemplatesIngestion()
            for clause_text, expected in cases:
                assert ingestion._infer_clause_category(clause_text) == expected
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Admin API Templates Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestTemplatesAdminAPI:
    """Tests for /api/v1/admin/templates/* endpoints.

    These tests mutate the module-global ``_templates_ingestion_status``
    dict from ``admin_api``. The original versions left their mutations
    in place, leaking state into any test that runs afterwards; each
    test now snapshots the dict and restores it in a ``finally`` block.
    """

    @staticmethod
    def _restore(status, snapshot):
        """Restore *status* in place from *snapshot* (same dict object)."""
        status.clear()
        status.update(snapshot)

    def test_templates_status_structure(self):
        """Test the structure of templates status response."""
        from admin_api import _templates_ingestion_status

        snapshot = dict(_templates_ingestion_status)
        try:
            # Reset status to a known-idle baseline.
            _templates_ingestion_status["running"] = False
            _templates_ingestion_status["last_run"] = None
            _templates_ingestion_status["current_source"] = None
            _templates_ingestion_status["results"] = {}

            assert _templates_ingestion_status["running"] is False
            assert _templates_ingestion_status["results"] == {}
        finally:
            self._restore(_templates_ingestion_status, snapshot)

    def test_templates_status_running(self):
        """Test status when ingestion is running."""
        from admin_api import _templates_ingestion_status

        snapshot = dict(_templates_ingestion_status)
        try:
            _templates_ingestion_status["running"] = True
            _templates_ingestion_status["current_source"] = "github-site-policy"
            _templates_ingestion_status["last_run"] = datetime.now().isoformat()

            assert _templates_ingestion_status["running"] is True
            assert _templates_ingestion_status["current_source"] == "github-site-policy"
        finally:
            self._restore(_templates_ingestion_status, snapshot)

    def test_templates_results_tracking(self):
        """Test that ingestion results are tracked correctly."""
        from admin_api import _templates_ingestion_status

        snapshot = dict(_templates_ingestion_status)
        try:
            _templates_ingestion_status["results"] = {
                "github-site-policy": {
                    "status": "completed",
                    "documents_found": 15,
                    "chunks_indexed": 42,
                    "errors": [],
                },
                "opr-vc": {
                    "status": "failed",
                    "documents_found": 0,
                    "chunks_indexed": 0,
                    "errors": ["Connection timeout"],
                },
            }

            results = _templates_ingestion_status["results"]
            assert results["github-site-policy"]["status"] == "completed"
            assert results["github-site-policy"]["chunks_indexed"] == 42
            assert results["opr-vc"]["status"] == "failed"
            assert len(results["opr-vc"]["errors"]) > 0
        finally:
            self._restore(_templates_ingestion_status, snapshot)
|
||||
|
||||
|
||||
class TestTemplateTypeLabels:
    """Tests for template type labels and constants."""

    def test_template_types_defined(self):
        """All core template types are present in TEMPLATE_TYPES."""
        from template_sources import TEMPLATE_TYPES

        for key in ("privacy_policy", "terms_of_service", "cookie_banner",
                    "impressum", "widerruf", "dpa"):
            assert key in TEMPLATE_TYPES

    def test_jurisdictions_defined(self):
        """All supported jurisdictions are present in JURISDICTIONS."""
        from template_sources import JURISDICTIONS

        for code in ("DE", "AT", "CH", "EU", "US"):
            assert code in JURISDICTIONS
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Qdrant Service Templates Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestQdrantServiceTemplates:
    """Tests for legal templates Qdrant service functions."""

    @pytest.fixture
    def mock_qdrant_client(self):
        """Mock Qdrant client for templates.

        NOTE(review): no test in this class currently requests this
        fixture; kept for parity with the original suite.
        """
        with patch('qdrant_service.get_qdrant_client') as mock:
            fake = MagicMock()
            fake.get_collections.return_value.collections = []
            mock.return_value = fake
            yield fake

    def test_legal_templates_collection_name(self):
        """The templates collection uses the expected name."""
        from qdrant_service import LEGAL_TEMPLATES_COLLECTION

        assert LEGAL_TEMPLATES_COLLECTION == "bp_legal_templates"

    def test_legal_templates_vector_size(self):
        """Vector size matches the BGE-M3 embedding dimension."""
        from qdrant_service import LEGAL_TEMPLATES_VECTOR_SIZE

        assert LEGAL_TEMPLATES_VECTOR_SIZE == 1024
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration Tests (require mocking external services)
|
||||
# =============================================================================
|
||||
|
||||
class TestTemplatesIntegration:
    """Integration tests for the templates system."""

    @pytest.fixture
    def mock_all_services(self):
        """Mock all external services (Qdrant + embeddings HTTP)."""
        with patch('legal_templates_ingestion.QdrantClient') as qdrant_cls, \
             patch('legal_templates_ingestion.httpx.AsyncClient') as http_cls:

            fake_qdrant = MagicMock()
            fake_qdrant.get_collections.return_value.collections = []
            qdrant_cls.return_value = fake_qdrant

            fake_http = AsyncMock()
            fake_http.post.return_value.json.return_value = {"embeddings": [[0.1] * 1024]}
            fake_http.post.return_value.raise_for_status = MagicMock()
            http_cls.return_value.__aenter__.return_value = fake_http

            yield {"qdrant": fake_qdrant, "http": fake_http}

    def test_full_chunk_creation_pipeline(self, mock_all_services):
        """Test the full chunk creation pipeline end to end."""
        from legal_templates_ingestion import LegalTemplatesIngestion
        from github_crawler import ExtractedDocument
        from template_sources import SourceConfig, LicenseType

        doc = ExtractedDocument(
            text="# Datenschutzerklaerung\n\nWir nehmen den Schutz Ihrer personenbezogenen Daten sehr ernst. Diese Datenschutzerklaerung informiert Sie ueber die Verarbeitung Ihrer Daten gemaess der DSGVO.",
            title="Datenschutzerklaerung",
            file_path="privacy.md",
            file_type="markdown",
            source_url="https://example.com/privacy.md",
            source_commit="abc123",
            placeholders=["[FIRMENNAME]"],
            language="de",  # explicit so language detection is not exercised here
        )
        source = SourceConfig(
            name="test-source",
            license_type=LicenseType.CC0,
            template_types=["privacy_policy"],
            languages=["de"],
            jurisdiction="DE",
            description="Test source",
            repo_url="https://github.com/test/repo",
        )

        chunks = LegalTemplatesIngestion()._create_chunks(doc, source)

        assert len(chunks) >= 1
        first = chunks[0]
        assert first.template_type == "privacy_policy"
        assert first.language == "de"
        assert first.jurisdiction == "DE"
        assert first.license_id == "cc0"
        assert first.attribution_required is False
        assert "[FIRMENNAME]" in first.placeholders
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Runner Configuration
|
||||
# =============================================================================
|
||||
|
||||
# Allow running this test module directly, outside the pytest CLI.
if __name__ == "__main__":
    pytest.main(args=[__file__, "-v"])
|
||||
349
klausur-service/backend/tests/test_mail_service.py
Normal file
349
klausur-service/backend/tests/test_mail_service.py
Normal file
@@ -0,0 +1,349 @@
|
||||
"""
|
||||
Unit Tests for Mail Module
|
||||
|
||||
Tests for:
|
||||
- TaskService: Priority calculation, deadline handling
|
||||
- AIEmailService: Sender classification, deadline extraction
|
||||
- Models: Validation, known authorities
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
# Import the modules to test
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from mail.models import (
|
||||
TaskPriority,
|
||||
TaskStatus,
|
||||
SenderType,
|
||||
EmailCategory,
|
||||
KNOWN_AUTHORITIES_NI,
|
||||
DeadlineExtraction,
|
||||
EmailAccountCreate,
|
||||
TaskCreate,
|
||||
classify_sender_by_domain,
|
||||
)
|
||||
from mail.task_service import TaskService
|
||||
from mail.ai_service import AIEmailService
|
||||
|
||||
|
||||
class TestKnownAuthoritiesNI:
    """Tests for Niedersachsen authority domain matching."""

    def _assert_domain_type(self, domain, expected_type):
        """Assert that *domain* is a known authority of *expected_type*."""
        assert domain in KNOWN_AUTHORITIES_NI
        assert KNOWN_AUTHORITIES_NI[domain]["type"] == expected_type

    def test_kultusministerium_domain(self):
        """Test that MK Niedersachsen domain is recognized."""
        self._assert_domain_type("@mk.niedersachsen.de", SenderType.KULTUSMINISTERIUM)

    def test_rlsb_domain(self):
        """Test that RLSB domain is recognized."""
        self._assert_domain_type("@rlsb.de", SenderType.RLSB)

    def test_landesschulbehoerde_domain(self):
        """Test that Landesschulbehörde domain is recognized."""
        self._assert_domain_type("@landesschulbehoerde-nds.de", SenderType.LANDESSCHULBEHOERDE)

    def test_nibis_domain(self):
        """Test that NiBiS domain is recognized."""
        self._assert_domain_type("@nibis.de", SenderType.NIBIS)

    def test_unknown_domain_not_in_list(self):
        """Ordinary provider domains are not treated as authorities."""
        for domain in ("@gmail.com", "@example.de"):
            assert domain not in KNOWN_AUTHORITIES_NI
|
||||
|
||||
|
||||
class TestTaskServicePriority:
    """Tests for TaskService priority calculation."""

    @pytest.fixture
    def task_service(self):
        return TaskService()

    def test_priority_from_kultusministerium(self, task_service):
        """Kultusministerium should result in HIGH priority."""
        assert task_service._get_priority_from_sender(SenderType.KULTUSMINISTERIUM) == TaskPriority.HIGH

    def test_priority_from_rlsb(self, task_service):
        """RLSB should result in HIGH priority."""
        assert task_service._get_priority_from_sender(SenderType.RLSB) == TaskPriority.HIGH

    def test_priority_from_nibis(self, task_service):
        """NiBiS should result in MEDIUM priority."""
        assert task_service._get_priority_from_sender(SenderType.NIBIS) == TaskPriority.MEDIUM

    def test_priority_from_privatperson(self, task_service):
        """Privatperson should result in LOW priority."""
        assert task_service._get_priority_from_sender(SenderType.PRIVATPERSON) == TaskPriority.LOW
|
||||
|
||||
|
||||
class TestTaskServiceDeadlineAdjustment:
    """Tests for TaskService deadline-based priority adjustment."""

    @pytest.fixture
    def task_service(self):
        return TaskService()

    def test_urgent_for_tomorrow(self, task_service):
        """Deadline tomorrow should be URGENT."""
        due = datetime.now() + timedelta(days=1)
        assert task_service._adjust_priority_for_deadline(TaskPriority.LOW, due) == TaskPriority.URGENT

    def test_urgent_for_today(self, task_service):
        """Deadline today should be URGENT."""
        due = datetime.now() + timedelta(hours=5)
        assert task_service._adjust_priority_for_deadline(TaskPriority.LOW, due) == TaskPriority.URGENT

    def test_high_for_3_days(self, task_service):
        """Deadline in 3 days with HIGH input stays HIGH."""
        due = datetime.now() + timedelta(days=3)
        # NOTE: max() compares enum members by value string, so the test
        # feeds HIGH rather than LOW here.
        assert task_service._adjust_priority_for_deadline(TaskPriority.HIGH, due) == TaskPriority.HIGH

    def test_medium_for_7_days(self, task_service):
        """Deadline in 7 days should be at least MEDIUM."""
        due = datetime.now() + timedelta(days=7)
        assert task_service._adjust_priority_for_deadline(TaskPriority.LOW, due) == TaskPriority.MEDIUM

    def test_no_change_for_far_deadline(self, task_service):
        """A far-future deadline leaves the priority unchanged."""
        due = datetime.now() + timedelta(days=30)
        assert task_service._adjust_priority_for_deadline(TaskPriority.LOW, due) == TaskPriority.LOW
|
||||
|
||||
|
||||
class TestTaskServiceDescriptionBuilder:
    """Tests for TaskService description building."""

    @pytest.fixture
    def task_service(self):
        return TaskService()

    def test_description_with_deadlines(self, task_service):
        """Description should include deadline information."""
        extraction = DeadlineExtraction(
            deadline_date=datetime(2026, 1, 15),
            description="Einreichung der Unterlagen",
            is_firm=True,
            confidence=0.9,
            source_text="bis zum 15.01.2026",
        )
        email_data = {
            "sender_email": "test@mk.niedersachsen.de",
            "body_preview": "Bitte reichen Sie die Unterlagen ein.",
        }

        text = task_service._build_task_description([extraction], email_data)

        # Deadline header, date, label, firmness marker and sender all appear.
        for expected in (
            "**Fristen:**",
            "15.01.2026",
            "Einreichung der Unterlagen",
            "(verbindlich)",
            "test@mk.niedersachsen.de",
        ):
            assert expected in text

    def test_description_without_deadlines(self, task_service):
        """Description should work without deadlines."""
        email_data = {
            "sender_email": "sender@example.de",
            "body_preview": "Test preview text",
        }

        text = task_service._build_task_description([], email_data)

        assert "**Fristen:**" not in text
        assert "sender@example.de" in text
|
||||
|
||||
|
||||
class TestSenderClassification:
    """Tests for sender classification via classify_sender_by_domain."""

    def test_classify_kultusministerium(self):
        """Email from MK should be classified correctly."""
        match = classify_sender_by_domain("referat@mk.niedersachsen.de")
        assert match is not None
        assert match.sender_type == SenderType.KULTUSMINISTERIUM

    def test_classify_rlsb(self):
        """Email from RLSB should be classified correctly."""
        match = classify_sender_by_domain("info@rlsb.de")
        assert match is not None
        assert match.sender_type == SenderType.RLSB

    def test_classify_unknown_domain(self):
        """Email from unknown domain should return None."""
        assert classify_sender_by_domain("user@gmail.com") is None
|
||||
|
||||
|
||||
class TestAIEmailServiceDeadlineExtraction:
    """Tests for AIEmailService deadline extraction from text."""

    @pytest.fixture
    def ai_service(self):
        return AIEmailService()

    def test_extract_deadline_bis_format(self, ai_service):
        """'bis zum DD.MM.YYYY' is extracted as a deadline."""
        hits = ai_service._extract_deadlines_regex(
            "Bitte senden Sie die Unterlagen bis zum 15.01.2027 ein."
        )

        assert len(hits) >= 1
        # At least one extracted deadline matches the expected date.
        assert any(d.deadline_date.strftime("%Y-%m-%d") == "2027-01-15" for d in hits)

    def test_extract_deadline_frist_format(self, ai_service):
        """'Frist: DD.MM.YYYY' is extracted as a deadline."""
        hits = ai_service._extract_deadlines_regex(
            "Die Frist: 20.02.2027 muss eingehalten werden."
        )

        assert len(hits) >= 1
        assert any(d.deadline_date.strftime("%Y-%m-%d") == "2027-02-20" for d in hits)

    def test_no_deadline_in_text(self, ai_service):
        """Text without a date yields no deadlines."""
        hits = ai_service._extract_deadlines_regex(
            "Dies ist eine allgemeine Mitteilung ohne Datum."
        )

        assert len(hits) == 0
|
||||
|
||||
|
||||
class TestAIEmailServiceCategoryRules:
    """Tests for AIEmailService category classification rules."""

    @pytest.fixture
    def ai_service(self):
        return AIEmailService()

    def test_fortbildung_category(self, ai_service):
        """Keywords like fortbildung/seminar/workshop map to FORTBILDUNG."""
        category, _confidence = ai_service._classify_category_rules(
            "Fortbildung NLQ Seminar",
            "Wir bieten eine Weiterbildung zum Thema Didaktik an.",
            SenderType.UNBEKANNT,
        )
        assert category == EmailCategory.FORTBILDUNG

    def test_personal_category(self, ai_service):
        """Keywords like personalrat/krankmeldung map to PERSONAL."""
        category, _confidence = ai_service._classify_category_rules(
            "Personalrat Sitzung",
            "Thema: Krankmeldung und Beurteilung",
            SenderType.UNBEKANNT,
        )
        assert category == EmailCategory.PERSONAL

    def test_finanzen_category(self, ai_service):
        """Keywords like budget/haushalt/abrechnung map to FINANZEN."""
        category, _confidence = ai_service._classify_category_rules(
            "Haushalt 2026 Budget",
            "Die Abrechnung und Erstattung für das neue Etat.",
            SenderType.UNBEKANNT,
        )
        assert category == EmailCategory.FINANZEN
|
||||
|
||||
|
||||
class TestEmailAccountCreateValidation:
    """Tests for the EmailAccountCreate Pydantic model."""

    @staticmethod
    def _make_account(**overrides):
        # Baseline valid payload; individual tests override fields as needed.
        payload = dict(
            email="test@example.com",
            display_name="Test Account",
            imap_host="imap.example.com",
            imap_port=993,
            smtp_host="smtp.example.com",
            smtp_port=587,
            password="secret",
        )
        payload.update(overrides)
        return EmailAccountCreate(**payload)

    def test_valid_account_creation(self):
        """A fully specified payload yields an account with those values."""
        account = self._make_account(
            email="schulleitung@grundschule-xy.de",
            display_name="Schulleitung",
            password="secret123",
        )
        assert account.email == "schulleitung@grundschule-xy.de"
        assert account.imap_port == 993
        assert account.imap_ssl is True  # model default

    def test_default_ssl_true(self):
        """Both IMAP and SMTP SSL flags default to True."""
        account = self._make_account()
        assert account.imap_ssl is True
        assert account.smtp_ssl is True
|
||||
|
||||
|
||||
class TestTaskCreateValidation:
    """Tests for the TaskCreate Pydantic model."""

    def test_valid_task_creation(self):
        """A fully specified payload keeps its title and priority."""
        task = TaskCreate(
            title="Unterlagen einreichen",
            description="Bitte alle Dokumente bis Freitag.",
            priority=TaskPriority.HIGH,
            deadline=datetime(2026, 1, 15),
        )
        assert task.title == "Unterlagen einreichen"
        assert task.priority == TaskPriority.HIGH

    def test_default_priority_medium(self):
        """Priority falls back to MEDIUM when omitted."""
        assert TaskCreate(title="Einfache Aufgabe").priority == TaskPriority.MEDIUM

    def test_optional_deadline(self):
        """Deadline is optional and defaults to None."""
        assert TaskCreate(title="Keine Frist").deadline is None
|
||||
|
||||
|
||||
# Integration test placeholder
|
||||
class TestMailModuleIntegration:
    """Integration tests (require database connection)."""

    @pytest.mark.skip(reason="Requires database connection")
    @pytest.mark.asyncio
    async def test_create_task_from_email(self):
        """Placeholder: create a task from an email analysis."""

    @pytest.mark.skip(reason="Requires database connection")
    @pytest.mark.asyncio
    async def test_dashboard_stats(self):
        """Placeholder: dashboard statistics calculation."""
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate pytest's exit status; previously the return value was
    # discarded, so `python test_....py` always exited 0 even on failures.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
||||
799
klausur-service/backend/tests/test_ocr_labeling.py
Normal file
799
klausur-service/backend/tests/test_ocr_labeling.py
Normal file
@@ -0,0 +1,799 @@
|
||||
"""
|
||||
Tests for OCR Labeling API
|
||||
Tests session management, image upload, labeling workflow, and training export.
|
||||
|
||||
BACKLOG: Feature not yet fully integrated - requires external OCR services
|
||||
See: https://macmini:3002/infrastructure/tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
# Mark all tests in this module as expected failures (backlog item)
|
||||
pytestmark = pytest.mark.xfail(
|
||||
reason="ocr_labeling requires external services not available in CI - Backlog item",
|
||||
strict=False # Don't fail if test unexpectedly passes
|
||||
)
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from datetime import datetime
|
||||
import io
|
||||
import json
|
||||
import hashlib
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Fixtures
|
||||
# =============================================================================
|
||||
|
||||
@pytest.fixture
def mock_db_pool():
    """Yield a (pool, conn) mock pair with the async context manager wired up."""
    with patch('metrics_db.get_pool') as get_pool:
        pool = AsyncMock()
        conn = AsyncMock()
        # `async with pool.acquire() as conn:` must resolve to our conn mock.
        pool.acquire.return_value.__aenter__.return_value = conn
        get_pool.return_value = pool
        yield pool, conn
|
||||
|
||||
|
||||
@pytest.fixture
def mock_minio():
    """Yield (upload_mock, get_mock) with MinIO marked as available."""
    with patch('ocr_labeling_api.MINIO_AVAILABLE', True), \
         patch('ocr_labeling_api.upload_ocr_image') as upload_stub, \
         patch('ocr_labeling_api.get_ocr_image') as download_stub:
        upload_stub.return_value = "ocr-labeling/session-123/item-456.png"
        download_stub.return_value = b"\x89PNG fake image data"
        yield upload_stub, download_stub
|
||||
|
||||
|
||||
@pytest.fixture
def mock_vision_ocr():
    """Yield a Vision-OCR service mock returning a fixed recognition result."""
    with patch('ocr_labeling_api.VISION_OCR_AVAILABLE', True), \
         patch('ocr_labeling_api.get_vision_ocr_service') as factory:
        ocr = AsyncMock()
        ocr.is_available.return_value = True
        # Fixed result so tests can assert on exact text/confidence.
        fake_result = MagicMock()
        fake_result.text = "Erkannter Text aus dem Bild"
        fake_result.confidence = 0.87
        ocr.extract_text.return_value = fake_result
        factory.return_value = ocr
        yield ocr
|
||||
|
||||
|
||||
@pytest.fixture
def mock_training_export():
    """Yield a training-export service mock with one canned export batch."""
    with patch('ocr_labeling_api.TRAINING_EXPORT_AVAILABLE', True), \
         patch('ocr_labeling_api.get_training_export_service') as factory:
        exporter = MagicMock()
        # Canned result of a single "generic" export run.
        result = MagicMock()
        result.export_path = "/app/ocr-exports/generic/20260121_120000"
        result.manifest_path = "/app/ocr-exports/generic/20260121_120000/manifest.json"
        result.batch_id = "20260121_120000"
        exporter.export.return_value = result
        exporter.list_exports.return_value = [
            {"format": "generic", "batch_id": "20260121_120000", "sample_count": 10}
        ]
        factory.return_value = exporter
        yield exporter
|
||||
|
||||
|
||||
@pytest.fixture
def sample_session():
    """Session row as the database layer would return it."""
    return {
        "id": "session-123",
        "name": "Test Session",
        "source_type": "klausur",
        "description": "Test description",
        "ocr_model": "llama3.2-vision:11b",
        # Progress counters: 2 of 5 items labeled (1 confirmed, 1 corrected).
        "total_items": 5,
        "labeled_items": 2,
        "confirmed_items": 1,
        "corrected_items": 1,
        "skipped_items": 0,
        "created_at": datetime.utcnow(),
    }
|
||||
|
||||
|
||||
@pytest.fixture
def sample_item():
    """Labeling item row in the 'pending' (not yet labeled) state."""
    return {
        "id": "item-456",
        "session_id": "session-123",
        "session_name": "Test Session",
        "image_path": "/app/ocr-labeling/session-123/item-456.png",
        "ocr_text": "Erkannter Text",
        "ocr_confidence": 0.87,
        "ground_truth": None,  # not yet confirmed/corrected
        "status": "pending",
        "metadata": {"page": 1},
        "created_at": datetime.utcnow(),
    }
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Session Management Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestSessionCreation:
    """Tests for session creation.

    Fix: removed the unused `SessionCreate` import and the unused `result`
    binding in test_create_session_success.
    """

    @pytest.mark.asyncio
    async def test_create_session_success(self, mock_db_pool):
        """create_ocr_labeling_session inserts via an acquired connection."""
        from metrics_db import create_ocr_labeling_session

        pool, conn = mock_db_pool
        conn.execute.return_value = None

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
            await create_ocr_labeling_session(
                session_id="session-123",
                name="Test Session",
                source_type="klausur",
                description="Test",
                ocr_model="llama3.2-vision:11b",
            )

        # The session row must be written through the acquired connection.
        assert pool.acquire.called

    def test_session_create_model_validation(self):
        """SessionCreate accepts valid data and applies the model default."""
        from ocr_labeling_api import SessionCreate

        session = SessionCreate(
            name="Test Session",
            source_type="klausur",
            description="Test description",
        )
        assert session.name == "Test Session"
        assert session.source_type == "klausur"
        assert session.ocr_model == "llama3.2-vision:11b"  # default

    def test_session_create_with_custom_model(self):
        """An explicit OCR model overrides the default."""
        from ocr_labeling_api import SessionCreate

        session = SessionCreate(
            name="TrOCR Session",
            source_type="handwriting_sample",
            ocr_model="trocr-base",
        )
        assert session.ocr_model == "trocr-base"
|
||||
|
||||
|
||||
class TestSessionListing:
    """Tests for session listing."""

    @pytest.mark.asyncio
    async def test_get_sessions_empty(self):
        """With no database pool available, the session list is empty."""
        from metrics_db import get_ocr_labeling_sessions

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None):
            assert await get_ocr_labeling_sessions() == []

    @pytest.mark.asyncio
    async def test_get_session_not_found(self):
        """Looking up an unknown session id yields None."""
        from metrics_db import get_ocr_labeling_session

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None):
            assert await get_ocr_labeling_session("non-existent-id") is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Image Upload Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestImageUpload:
    """Tests for image upload helpers (hashing, local storage, URL mapping).

    Fix: test_save_image_locally imported `save_image_locally` twice (once
    at the top of the test and again inside the `with` block); the
    redundant first import is removed.
    """

    def test_compute_image_hash(self):
        """Hashing is deterministic and returns a SHA256 hex digest."""
        from ocr_labeling_api import compute_image_hash

        image_data = b"\x89PNG fake image data"
        hash1 = compute_image_hash(image_data)
        hash2 = compute_image_hash(image_data)

        # Same bytes -> same digest; 64 hex chars for SHA256.
        assert hash1 == hash2
        assert len(hash1) == 64

    def test_compute_image_hash_different_data(self):
        """Different payloads hash to different digests."""
        from ocr_labeling_api import compute_image_hash

        assert compute_image_hash(b"image 1 data") != compute_image_hash(b"image 2 data")

    def test_save_image_locally(self, tmp_path):
        """Images are written under the (patched) local storage root."""
        import os

        # Redirect the storage root into pytest's tmp_path for isolation.
        with patch('ocr_labeling_api.LOCAL_STORAGE_PATH', str(tmp_path)):
            from ocr_labeling_api import save_image_locally

            filepath = save_image_locally(
                session_id="session-123",
                item_id="item-456",
                image_data=b"\x89PNG fake image data",
                extension="png",
            )

        assert filepath.endswith("item-456.png")
        assert os.path.exists(filepath)

    def test_get_image_url_local(self):
        """Paths under the local storage root map to the API image endpoint."""
        from ocr_labeling_api import get_image_url, LOCAL_STORAGE_PATH

        local_path = f"{LOCAL_STORAGE_PATH}/session-123/item-456.png"
        assert get_image_url(local_path) == "/api/v1/ocr-label/images/session-123/item-456.png"

    def test_get_image_url_minio(self):
        """Non-local (MinIO) paths are returned unchanged."""
        from ocr_labeling_api import get_image_url

        minio_path = "ocr-labeling/session-123/item-456.png"
        assert get_image_url(minio_path) == minio_path
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Labeling Workflow Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestConfirmLabel:
    """Tests for label confirmation."""

    @pytest.mark.asyncio
    async def test_confirm_label_success(self, mock_db_pool):
        """Confirming a label persists status and ground truth via the connection."""
        from metrics_db import confirm_ocr_label

        pool, conn = mock_db_pool
        conn.fetchrow.return_value = {"ocr_text": "Test text"}
        conn.execute.return_value = None

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
            await confirm_ocr_label(
                item_id="item-456",
                labeled_by="admin",
                label_time_seconds=5,
            )

        # Status and ground_truth updates go through conn.execute.
        assert conn.execute.called

    def test_confirm_request_validation(self):
        """ConfirmRequest keeps the supplied field values."""
        from ocr_labeling_api import ConfirmRequest

        req = ConfirmRequest(item_id="item-456", label_time_seconds=5)
        assert req.item_id == "item-456"
        assert req.label_time_seconds == 5
|
||||
|
||||
|
||||
class TestCorrectLabel:
    """Tests for label correction."""

    @pytest.mark.asyncio
    async def test_correct_label_success(self, mock_db_pool):
        """Correcting a label writes the corrected ground truth via the connection."""
        from metrics_db import correct_ocr_label

        pool, conn = mock_db_pool
        conn.execute.return_value = None

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
            await correct_ocr_label(
                item_id="item-456",
                ground_truth="Korrigierter Text",
                labeled_by="admin",
                label_time_seconds=15,
            )

        # The corrected ground_truth must be persisted via conn.execute.
        assert conn.execute.called

    def test_correct_request_validation(self):
        """CorrectRequest keeps the supplied field values."""
        from ocr_labeling_api import CorrectRequest

        req = CorrectRequest(
            item_id="item-456",
            ground_truth="Korrigierter Text",
            label_time_seconds=15,
        )
        assert req.item_id == "item-456"
        assert req.ground_truth == "Korrigierter Text"
|
||||
|
||||
|
||||
class TestSkipItem:
    """Tests for item skipping."""

    @pytest.mark.asyncio
    async def test_skip_item_success(self, mock_db_pool):
        """Skipping an item updates its status through the connection."""
        from metrics_db import skip_ocr_item

        pool, conn = mock_db_pool
        conn.execute.return_value = None

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
            await skip_ocr_item(
                item_id="item-456",
                labeled_by="admin",
            )

        # The 'skipped' status must be written via conn.execute.
        assert conn.execute.called
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Statistics Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestLabelingStats:
    """Tests for labeling statistics."""

    @pytest.mark.asyncio
    async def test_get_stats_no_db(self):
        """Without a database pool, stats report an error or zero items."""
        from metrics_db import get_ocr_labeling_stats

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None):
            stats = await get_ocr_labeling_stats()

        assert "error" in stats or stats.get("total_items", 0) == 0

    def test_stats_response_model(self):
        """StatsResponse stores the provided counters and accuracy rate."""
        from ocr_labeling_api import StatsResponse

        stats = StatsResponse(
            total_items=100,
            labeled_items=50,
            confirmed_items=40,
            corrected_items=10,
            pending_items=50,
            accuracy_rate=0.8,
        )
        assert stats.total_items == 100
        assert stats.accuracy_rate == 0.8
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Export Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestTrainingExport:
    """Tests for training data export."""

    def test_export_request_validation(self):
        """ExportRequest defaults to 'generic' and accepts the known formats."""
        from ocr_labeling_api import ExportRequest

        assert ExportRequest().export_format == "generic"
        for fmt in ("trocr", "llama_vision"):
            assert ExportRequest(export_format=fmt).export_format == fmt

    @pytest.mark.asyncio
    async def test_export_training_samples(self, mock_db_pool):
        """Exporting reads labeled samples from the database."""
        from metrics_db import export_training_samples

        pool, conn = mock_db_pool
        conn.fetch.return_value = [
            {
                "id": "sample-1",
                "image_path": "/app/ocr-labeling/session-123/item-1.png",
                "ground_truth": "Text 1",
            },
            {
                "id": "sample-2",
                "image_path": "/app/ocr-labeling/session-123/item-2.png",
                "ground_truth": "Text 2",
            },
        ]

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=pool):
            await export_training_samples(
                export_format="generic",
                exported_by="admin",
            )

        # Samples must be read (fetch) or the export recorded (execute).
        assert conn.fetch.called or conn.execute.called
|
||||
|
||||
|
||||
class TestTrainingExportService:
    """Shape documentation for the training export formats.

    NOTE(review): these assertions only check locally built literals; they
    document the expected export shapes but do not exercise the export
    service itself. Consider asserting against real exporter output.
    """

    def test_trocr_export_format(self):
        """A TrOCR sample pairs a file name with its target text."""
        sample = {
            "file_name": "images/sample-1.png",
            "text": "Ground truth text",
            "id": "sample-1",
        }
        assert "file_name" in sample
        assert "text" in sample

    def test_llama_vision_export_format(self):
        """A Llama Vision sample is a chat transcript ending with the answer."""
        sample = {
            "id": "sample-1",
            "messages": [
                {"role": "system", "content": "Du bist ein OCR-Experte..."},
                {"role": "user", "content": [
                    {"type": "image_url", "image_url": {"url": "..."}},
                    {"type": "text", "text": "Lies den Text..."},
                ]},
                {"role": "assistant", "content": "Ground truth text"},
            ],
        }
        assert "messages" in sample
        assert len(sample["messages"]) == 3
        assert sample["messages"][2]["role"] == "assistant"

    def test_generic_export_format(self):
        """A generic sample pairs image path, ground truth and OCR output."""
        sample = {
            "id": "sample-1",
            "image_path": "images/sample-1.png",
            "ground_truth": "Ground truth text",
            "ocr_text": "OCR recognized text",
            "ocr_confidence": 0.87,
            "metadata": {},
        }
        assert "image_path" in sample
        assert "ground_truth" in sample
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OCR Processing Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestOCRProcessing:
    """Tests for OCR processing."""

    @pytest.mark.asyncio
    async def test_run_ocr_on_image_no_service(self):
        """With every OCR backend disabled, no text and zero confidence."""
        from ocr_labeling_api import run_ocr_on_image

        with patch('ocr_labeling_api.VISION_OCR_AVAILABLE', False), \
             patch('ocr_labeling_api.PADDLEOCR_AVAILABLE', False), \
             patch('ocr_labeling_api.TROCR_AVAILABLE', False), \
             patch('ocr_labeling_api.DONUT_AVAILABLE', False):
            text, confidence = await run_ocr_on_image(
                image_data=b"fake image",
                filename="test.png",
            )

        assert text is None
        assert confidence == 0.0

    @pytest.mark.asyncio
    async def test_run_ocr_on_image_success(self, mock_vision_ocr):
        """With the Vision OCR mock active, its canned result is returned."""
        from ocr_labeling_api import run_ocr_on_image

        text, confidence = await run_ocr_on_image(
            image_data=b"fake image",
            filename="test.png",
        )

        assert text == "Erkannter Text aus dem Bild"
        assert confidence == 0.87
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OCR Model Dispatcher Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestOCRModelDispatcher:
    """Tests for the OCR model dispatcher (v1.1.0)."""

    @staticmethod
    async def _dispatch(model):
        # All dispatcher tests use the same dummy image; only `model` varies.
        from ocr_labeling_api import run_ocr_on_image
        return await run_ocr_on_image(
            image_data=b"fake image",
            filename="test.png",
            model=model,
        )

    @pytest.mark.asyncio
    async def test_dispatcher_vision_model_default(self, mock_vision_ocr):
        """The default vision model routes to Vision OCR."""
        text, confidence = await self._dispatch("llama3.2-vision:11b")
        assert text == "Erkannter Text aus dem Bild"
        assert confidence == 0.87

    @pytest.mark.asyncio
    async def test_dispatcher_paddleocr_model(self):
        """'paddleocr' routes to the PaddleOCR backend."""
        with patch('ocr_labeling_api.PADDLEOCR_AVAILABLE', True), \
             patch('ocr_labeling_api.run_paddle_ocr',
                   return_value=([], "PaddleOCR erkannter Text")):
            text, _confidence = await self._dispatch("paddleocr")
        assert text == "PaddleOCR erkannter Text"

    @pytest.mark.asyncio
    async def test_dispatcher_paddleocr_fallback_to_vision(self, mock_vision_ocr):
        """'paddleocr' falls back to Vision OCR when PaddleOCR is unavailable."""
        with patch('ocr_labeling_api.PADDLEOCR_AVAILABLE', False):
            text, confidence = await self._dispatch("paddleocr")
        assert text == "Erkannter Text aus dem Bild"
        assert confidence == 0.87

    @pytest.mark.asyncio
    async def test_dispatcher_trocr_model(self):
        """'trocr' routes to the TrOCR backend."""
        async def fake_trocr(image_data):
            return "TrOCR erkannter Text", 0.85

        with patch('ocr_labeling_api.TROCR_AVAILABLE', True), \
             patch('ocr_labeling_api.run_trocr_ocr', fake_trocr):
            text, confidence = await self._dispatch("trocr")
        assert text == "TrOCR erkannter Text"
        assert confidence == 0.85

    @pytest.mark.asyncio
    async def test_dispatcher_donut_model(self):
        """'donut' routes to the Donut backend."""
        async def fake_donut(image_data):
            return "Donut erkannter Text", 0.80

        with patch('ocr_labeling_api.DONUT_AVAILABLE', True), \
             patch('ocr_labeling_api.run_donut_ocr', fake_donut):
            text, confidence = await self._dispatch("donut")
        assert text == "Donut erkannter Text"
        assert confidence == 0.80

    @pytest.mark.asyncio
    async def test_dispatcher_unknown_model_uses_vision(self, mock_vision_ocr):
        """Unknown model names fall back to Vision OCR."""
        text, confidence = await self._dispatch("unknown-model")
        assert text == "Erkannter Text aus dem Bild"
        assert confidence == 0.87
|
||||
|
||||
|
||||
class TestOCRModelTypes:
    """Tests for OCR model type definitions."""

    @staticmethod
    def _make_session(name, source_type, ocr_model):
        # Shared factory: only the model-relevant fields vary per test.
        from ocr_labeling_api import SessionCreate
        return SessionCreate(name=name, source_type=source_type, ocr_model=ocr_model)

    def test_session_with_paddleocr_model(self):
        """A session can be created with the PaddleOCR model."""
        session = self._make_session("PaddleOCR Session", "klausur", "paddleocr")
        assert session.ocr_model == "paddleocr"

    def test_session_with_donut_model(self):
        """A session can be created with the Donut model."""
        session = self._make_session("Donut Session", "scan", "donut")
        assert session.ocr_model == "donut"

    def test_session_with_trocr_model(self):
        """A session can be created with the TrOCR model."""
        session = self._make_session("TrOCR Session", "handwriting_sample", "trocr")
        assert session.ocr_model == "trocr"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# API Response Model Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestResponseModels:
    """Tests for API response models."""

    def test_session_response_model(self):
        """SessionResponse stores identity fields and progress counters."""
        from ocr_labeling_api import SessionResponse

        payload = dict(
            id="session-123",
            name="Test Session",
            source_type="klausur",
            description="Test",
            ocr_model="llama3.2-vision:11b",
            total_items=10,
            labeled_items=5,
            confirmed_items=3,
            corrected_items=2,
            skipped_items=0,
            created_at=datetime.utcnow(),
        )
        session = SessionResponse(**payload)

        assert session.id == "session-123"
        assert session.total_items == 10

    def test_item_response_model(self):
        """ItemResponse stores identity, image info and labeling status."""
        from ocr_labeling_api import ItemResponse

        payload = dict(
            id="item-456",
            session_id="session-123",
            session_name="Test Session",
            image_path="/app/ocr-labeling/session-123/item-456.png",
            image_url="/api/v1/ocr-label/images/session-123/item-456.png",
            ocr_text="Test OCR text",
            ocr_confidence=0.87,
            ground_truth=None,
            status="pending",
            metadata={"page": 1},
            created_at=datetime.utcnow(),
        )
        item = ItemResponse(**payload)

        assert item.id == "item-456"
        assert item.status == "pending"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Deduplication Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestDeduplication:
    """Tests for image deduplication via content hashing."""

    def test_hash_based_deduplication(self):
        """Byte-identical images hash identically, enabling duplicate detection."""
        from ocr_labeling_api import compute_image_hash

        copy_a = b"\x89PNG\x0d\x0a\x1a\x0a test image content"
        copy_b = b"\x89PNG\x0d\x0a\x1a\x0a test image content"
        assert compute_image_hash(copy_a) == compute_image_hash(copy_b)

    def test_unique_images_different_hash(self):
        """Distinct image bytes hash to distinct digests."""
        from ocr_labeling_api import compute_image_hash

        digest_a = compute_image_hash(b"\x89PNG unique content 1")
        digest_b = compute_image_hash(b"\x89PNG unique content 2")
        assert digest_a != digest_b
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration Tests (require running services)
|
||||
# =============================================================================
|
||||
|
||||
class TestOCRLabelingIntegration:
    """Integration tests - require Ollama, MinIO, PostgreSQL running."""

    @pytest.mark.skip(reason="Requires running Ollama with llama3.2-vision")
    @pytest.mark.asyncio
    async def test_full_labeling_workflow(self):
        """Placeholder: end-to-end workflow.

        Would cover: create session -> upload image -> run OCR ->
        confirm or correct label -> export training data.
        """

    @pytest.mark.skip(reason="Requires running PostgreSQL")
    @pytest.mark.asyncio
    async def test_stats_calculation(self):
        """Placeholder: statistics calculation with real data."""

    @pytest.mark.skip(reason="Requires running MinIO")
    @pytest.mark.asyncio
    async def test_image_storage_and_retrieval(self):
        """Placeholder: image upload and download from MinIO."""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Run Tests
|
||||
# =============================================================================
|
||||
|
||||
if __name__ == "__main__":
    # Propagate pytest's exit status; previously the return value was
    # discarded, so `python test_ocr_labeling.py` always exited 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
||||
356
klausur-service/backend/tests/test_rag_admin.py
Normal file
356
klausur-service/backend/tests/test_rag_admin.py
Normal file
@@ -0,0 +1,356 @@
|
||||
"""
|
||||
Tests for RAG Admin API
|
||||
Tests upload, search, metrics, and storage functionality.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from datetime import datetime
|
||||
import io
|
||||
import zipfile
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Fixtures
|
||||
# =============================================================================
|
||||
|
||||
@pytest.fixture
def mock_qdrant_client():
    """Yield a Qdrant client mock reporting one healthy collection snapshot."""
    with patch('admin_api.get_qdrant_client') as factory:
        client = MagicMock()
        client.get_collections.return_value.collections = []
        # Collection info: fully indexed, healthy ("green") status.
        info = client.get_collection.return_value
        info.vectors_count = 7352
        info.points_count = 7352
        info.status.value = "green"
        factory.return_value = client
        yield client
|
||||
|
||||
|
||||
@pytest.fixture
def mock_minio_client():
    """Patch ``minio_storage._get_minio_client`` and yield the fake client.

    The fake pretends the bucket exists and contains no objects.
    """
    with patch('minio_storage._get_minio_client') as factory:
        fake = MagicMock()
        fake.bucket_exists.return_value = True
        fake.list_objects.return_value = []
        factory.return_value = fake
        yield fake
|
||||
|
||||
|
||||
@pytest.fixture
def mock_db_pool():
    """Patch ``metrics_db.get_pool`` and yield the fake async pool."""
    with patch('metrics_db.get_pool') as factory:
        fake_pool = AsyncMock()
        factory.return_value = fake_pool
        yield fake_pool
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Admin API Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestIngestionStatus:
    """Tests for /api/v1/admin/nibis/status endpoint.

    Both tests mutate the module-level ``_ingestion_status`` dict in
    ``admin_api`` directly rather than going through the HTTP layer.
    """

    def test_status_not_running(self):
        """Idle state: running flag cleared and no run recorded."""
        from admin_api import _ingestion_status

        # Force the shared dict into a known idle state first.
        _ingestion_status.update(running=False, last_run=None, last_result=None)

        assert _ingestion_status["running"] is False

    def test_status_running(self):
        """Active state: running flag set and a timestamp recorded."""
        from admin_api import _ingestion_status

        _ingestion_status["running"] = True
        _ingestion_status["last_run"] = datetime.now().isoformat()

        assert _ingestion_status["running"] is True
        assert _ingestion_status["last_run"] is not None
|
||||
|
||||
|
||||
class TestUploadAPI:
    """Tests for /api/v1/admin/rag/upload endpoint.

    These operate on the module-level ``_upload_history`` list in
    ``admin_api``, simulating what the upload handler appends.
    """

    def test_upload_record_creation(self):
        """A single upload record lands in the in-memory history."""
        from admin_api import _upload_history

        _upload_history.clear()

        record = {
            "timestamp": datetime.now().isoformat(),
            "filename": "test.pdf",
            "collection": "bp_nibis_eh",
            "year": 2024,
            "pdfs_extracted": 1,
            "target_directory": "/tmp/test",
        }
        _upload_history.append(record)

        assert len(_upload_history) == 1
        assert _upload_history[0]["filename"] == "test.pdf"

    def test_upload_history_limit(self):
        """History is capped at 100 entries, dropping the oldest first."""
        from admin_api import _upload_history

        _upload_history.clear()

        # Append 105 records, trimming after each append as the handler does.
        for index in range(105):
            _upload_history.append({
                "timestamp": datetime.now().isoformat(),
                "filename": f"test_{index}.pdf",
            })
            if len(_upload_history) > 100:
                _upload_history.pop(0)

        assert len(_upload_history) == 100
|
||||
|
||||
|
||||
class TestSearchFeedback:
    """Tests for feedback storage."""

    def test_feedback_record_format(self):
        """A feedback record carries a timestamp and a rating in 1..5."""
        record = {
            "timestamp": datetime.now().isoformat(),
            "result_id": "test-123",
            "rating": 4,
            "notes": "Good result",
        }

        assert "timestamp" in record
        assert 1 <= record["rating"] <= 5
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MinIO Storage Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestMinIOStorage:
    """Tests for MinIO storage functions.

    Only pure path logic and the "no client" degradation path are
    covered; real object storage is exercised in TestRAGIntegration.
    """

    def test_get_minio_path(self):
        """get_minio_path builds the landes-daten object key."""
        from minio_storage import get_minio_path

        path = get_minio_path(
            data_type="landes-daten",
            bundesland="ni",
            use_case="klausur",
            year=2024,
            filename="test.pdf",
        )

        assert path == "landes-daten/ni/klausur/2024/test.pdf"

    def test_get_minio_path_teacher_data(self):
        """Pin the teacher-data path convention as a string.

        Teacher data does not go through ``get_minio_path``; it uses a
        separate ``lehrer-daten/<tenant>/<teacher>/<file>.enc`` layout.
        Fixed: removed an unused ``get_minio_path`` import and the
        pointless ``f`` prefix on a string with no placeholders (F541).
        TODO(review): replace with a call to the real path builder for
        teacher data once one exists.
        """
        path = "lehrer-daten/tenant_123/teacher_456/test.pdf.enc"

        assert "lehrer-daten" in path
        assert "tenant_123" in path
        assert ".enc" in path

    @pytest.mark.asyncio
    async def test_storage_stats_no_client(self):
        """get_storage_stats reports disconnected when MinIO is absent."""
        from minio_storage import get_storage_stats

        with patch('minio_storage._get_minio_client', return_value=None):
            stats = await get_storage_stats()
            assert stats["connected"] is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Metrics DB Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestMetricsDB:
    """Tests for PostgreSQL metrics functions.

    Only the "no database pool" degradation paths are covered here;
    real persistence is exercised in TestRAGIntegration.
    """

    @pytest.mark.asyncio
    async def test_store_feedback_no_pool(self):
        """store_feedback returns False when no DB pool is available."""
        from metrics_db import store_feedback

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None):
            result = await store_feedback(
                result_id="test-123",
                rating=4,
            )
            assert result is False

    @pytest.mark.asyncio
    async def test_calculate_metrics_no_pool(self):
        """calculate_metrics reports disconnected when no DB pool exists."""
        from metrics_db import calculate_metrics

        with patch('metrics_db.get_pool', new_callable=AsyncMock, return_value=None):
            metrics = await calculate_metrics()
            assert metrics["connected"] is False

    def test_create_tables_sql_structure(self):
        """init_metrics_tables is importable and callable.

        Fixed: removed a dead ``expected_tables`` list that was declared
        but never asserted against, plus its misleading comment.
        TODO(review): strengthen this test by inspecting the SQL emitted
        by ``init_metrics_tables`` for the expected table names
        (rag_search_feedback, rag_search_logs, rag_upload_history).
        """
        from metrics_db import init_metrics_tables

        assert callable(init_metrics_tables)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration Tests (require running services)
|
||||
# =============================================================================
|
||||
|
||||
class TestRAGIntegration:
    """Integration tests - require Qdrant, MinIO, PostgreSQL running.

    All tests are skipped placeholders; the commented-out bodies sketch
    the intended calls so they can be enabled once live services exist.
    """

    @pytest.mark.skip(reason="Requires running Qdrant")
    @pytest.mark.asyncio
    async def test_nibis_search(self):
        """Test NiBiS semantic search."""
        from admin_api import search_nibis
        from admin_api import NiBiSSearchRequest

        # NOTE: the request object is deliberately built (and unused)
        # to document the intended query payload for the disabled call.
        request = NiBiSSearchRequest(
            query="Gedichtanalyse Expressionismus",
            limit=5,
        )

        # This would require Qdrant running
        # results = await search_nibis(request)
        # assert len(results) <= 5

    @pytest.mark.skip(reason="Requires running MinIO")
    @pytest.mark.asyncio
    async def test_minio_upload(self):
        """Test MinIO document upload."""
        from minio_storage import upload_rag_document

        # Minimal PDF-looking payload for the disabled upload call below.
        test_content = b"%PDF-1.4 test content"

        # This would require MinIO running
        # path = await upload_rag_document(
        #     file_data=test_content,
        #     filename="test.pdf",
        #     bundesland="ni",
        #     use_case="klausur",
        #     year=2024,
        # )
        # assert path is not None

    @pytest.mark.skip(reason="Requires running PostgreSQL")
    @pytest.mark.asyncio
    async def test_metrics_storage(self):
        """Test metrics storage in PostgreSQL."""
        from metrics_db import store_feedback, calculate_metrics

        # This would require PostgreSQL running
        # stored = await store_feedback(
        #     result_id="test-123",
        #     rating=4,
        #     query_text="test query",
        # )
        # assert stored is True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# ZIP Handling Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestZIPHandling:
    """Tests for ZIP file extraction (pure stdlib, fully offline)."""

    def test_create_test_zip(self):
        """Round-trip: write three PDFs into an in-memory ZIP, read names back."""
        buffer = io.BytesIO()

        entries = {
            "test1.pdf": b"%PDF-1.4 test content 1",
            "test2.pdf": b"%PDF-1.4 test content 2",
            "subfolder/test3.pdf": b"%PDF-1.4 test content 3",
        }
        with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as archive:
            for name, payload in entries.items():
                archive.writestr(name, payload)

        buffer.seek(0)

        # Verify ZIP contents
        with zipfile.ZipFile(buffer, 'r') as archive:
            stored_names = archive.namelist()
        for name in entries:
            assert name in stored_names

    def test_filter_macosx_files(self):
        """__MACOSX resource-fork entries are excluded from the PDF list."""
        buffer = io.BytesIO()

        with zipfile.ZipFile(buffer, 'w', zipfile.ZIP_DEFLATED) as archive:
            archive.writestr("test.pdf", b"%PDF-1.4 test")
            archive.writestr("__MACOSX/._test.pdf", b"macosx metadata")

        buffer.seek(0)

        with zipfile.ZipFile(buffer, 'r') as archive:
            pdfs = []
            for name in archive.namelist():
                # Same filter the upload handler applies to uploaded ZIPs.
                if name.lower().endswith(".pdf") and not name.startswith("__MACOSX"):
                    pdfs.append(name)

        assert len(pdfs) == 1
        assert pdfs[0] == "test.pdf"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Embedding Tests
|
||||
# =============================================================================
|
||||
|
||||
class TestEmbeddings:
    """Tests for embedding generation configuration."""

    def test_vector_dimensions(self):
        """Vector size matches the configured embedding backend."""
        from eh_pipeline import get_vector_size, EMBEDDING_BACKEND

        # Known backend -> dimension pairs; unknown backends are not checked.
        expected_sizes = {
            "local": 384,    # all-MiniLM-L6-v2
            "openai": 1536,  # text-embedding-3-small
        }
        size = get_vector_size()
        if EMBEDDING_BACKEND in expected_sizes:
            assert size == expected_sizes[EMBEDDING_BACKEND]

    def test_chunking_config(self):
        """Chunk size is positive and overlap stays strictly below it."""
        from eh_pipeline import CHUNK_SIZE, CHUNK_OVERLAP

        assert CHUNK_SIZE > 0
        assert 0 <= CHUNK_OVERLAP < CHUNK_SIZE
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Run Tests
|
||||
# =============================================================================
|
||||
|
||||
# Allow direct execution (`python <this file>`) outside the pytest runner.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
1236
klausur-service/backend/tests/test_rbac.py
Normal file
1236
klausur-service/backend/tests/test_rbac.py
Normal file
File diff suppressed because it is too large
Load Diff
623
klausur-service/backend/tests/test_vocab_worksheet.py
Normal file
623
klausur-service/backend/tests/test_vocab_worksheet.py
Normal file
@@ -0,0 +1,623 @@
|
||||
"""
|
||||
Unit Tests for Vocab-Worksheet API
|
||||
|
||||
Tests cover:
|
||||
- Session CRUD (create, read, list, delete)
|
||||
- File upload (images and PDFs)
|
||||
- PDF page handling (thumbnails, page selection)
|
||||
- Vocabulary extraction (mocked Vision LLM)
|
||||
- Vocabulary editing
|
||||
- Worksheet generation
|
||||
- PDF export
|
||||
|
||||
DSGVO Note: All tests run locally without external API calls.
|
||||
|
||||
BACKLOG: Feature not yet integrated into main.py
|
||||
See: https://macmini:3002/infrastructure/tests
|
||||
"""
|
||||
|
||||
import pytest
import sys
from pathlib import Path

# Skip entire module if vocab_worksheet_api is not available
pytest.importorskip("vocab_worksheet_api", reason="vocab_worksheet_api not yet integrated - Backlog item")

# Mark all tests in this module as expected failures (backlog item)
pytestmark = pytest.mark.xfail(
    reason="vocab_worksheet_api not yet integrated into main.py - Backlog item",
    strict=False  # Don't fail if test unexpectedly passes
)
import json
import uuid
import io
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient

# Import the main app and vocab-worksheet components.
# Fixed: this was previously the CWD-relative path '..', which only worked
# when pytest was launched from the tests directory; resolve it relative to
# this file instead (same approach as conftest.py).
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from main import app
from vocab_worksheet_api import (
    _sessions,
    _worksheets,
    SessionStatus,
    WorksheetType,
    VocabularyEntry,
    SessionCreate,
    VocabularyUpdate,
    WorksheetGenerateRequest,
    parse_vocabulary_json,
)
|
||||
|
||||
|
||||
# =============================================
|
||||
# FIXTURES
|
||||
# =============================================
|
||||
|
||||
@pytest.fixture
def client():
    """Test client for FastAPI app.

    Synchronous TestClient wrapping the imported ``app``; requests run
    in-process, no server needed.
    """
    return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def clear_storage():
    """Reset the in-memory session/worksheet stores around every test."""
    for store in (_sessions, _worksheets):
        store.clear()
    yield
    for store in (_sessions, _worksheets):
        store.clear()
|
||||
|
||||
|
||||
@pytest.fixture
def sample_session_data():
    """Typical session-creation payload (English -> German)."""
    return dict(
        name="Englisch Klasse 7 - Unit 3",
        description="Vokabeln aus Green Line 3",
        source_language="en",
        target_language="de",
    )
|
||||
|
||||
|
||||
@pytest.fixture
def sample_vocabulary():
    """Three vocabulary entries (two verbs, one noun), all from page 1."""
    rows = [
        ("to achieve", "erreichen, erzielen", "She achieved her goals.", "v"),
        ("achievement", "Leistung, Errungenschaft", "That was a great achievement.", "n"),
        ("improve", "verbessern", "I want to improve my English.", "v"),
    ]
    return [
        {
            "id": str(uuid.uuid4()),
            "english": english,
            "german": german,
            "example_sentence": example,
            "word_type": word_type,
            "source_page": 1,
        }
        for english, german, example, word_type in rows
    ]
|
||||
|
||||
|
||||
@pytest.fixture
def sample_image_bytes():
    """Create a minimal valid PNG image (1x1 pixel, white).

    Same bytes as before, expressed as hex for compactness.
    """
    return bytes.fromhex(
        "89504e470d0a1a0a"            # PNG signature
        "0000000d49484452"            # IHDR length + chunk type
        "0000000100000001"            # width=1, height=1
        "0802000000907753de"          # bit depth, colour type, flags, CRC
        "0000000c49444154"            # IDAT length + chunk type
        "08d763f8ffffff0005fe02fe"    # compressed white pixel
        "dccc59e7"                    # IDAT CRC
        "0000000049454e44"            # IEND length + chunk type
        "ae426082"                    # IEND CRC
    )
|
||||
|
||||
|
||||
# =============================================
|
||||
# SESSION TESTS
|
||||
# =============================================
|
||||
|
||||
class TestSessionCRUD:
    """Test session create, read, update, delete operations.

    All requests go through the FastAPI TestClient; storage is the
    in-memory ``_sessions`` dict, reset by the autouse fixture.
    """

    def test_create_session(self, client, sample_session_data):
        """Test creating a new vocabulary session."""
        response = client.post("/api/v1/vocab/sessions", json=sample_session_data)

        assert response.status_code == 200
        data = response.json()

        # Server assigns the id and starts the session empty/pending.
        assert "id" in data
        assert data["name"] == sample_session_data["name"]
        assert data["description"] == sample_session_data["description"]
        assert data["source_language"] == "en"
        assert data["target_language"] == "de"
        assert data["status"] == "pending"
        assert data["vocabulary_count"] == 0

    def test_create_session_minimal(self, client):
        """Test creating session with minimal data (only a name)."""
        response = client.post("/api/v1/vocab/sessions", json={"name": "Test"})

        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "Test"
        assert data["source_language"] == "en"  # Default
        assert data["target_language"] == "de"  # Default

    def test_list_sessions_empty(self, client):
        """Test listing sessions when none exist."""
        response = client.get("/api/v1/vocab/sessions")

        assert response.status_code == 200
        assert response.json() == []

    def test_list_sessions(self, client, sample_session_data):
        """Test listing sessions after creating some."""
        # Create 3 sessions
        for i in range(3):
            data = sample_session_data.copy()
            data["name"] = f"Session {i+1}"
            client.post("/api/v1/vocab/sessions", json=data)

        response = client.get("/api/v1/vocab/sessions")

        assert response.status_code == 200
        sessions = response.json()
        assert len(sessions) == 3

    def test_get_session(self, client, sample_session_data):
        """Test getting a specific session by id."""
        # Create session
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]

        # Get session
        response = client.get(f"/api/v1/vocab/sessions/{session_id}")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == session_id
        assert data["name"] == sample_session_data["name"]

    def test_get_session_not_found(self, client):
        """Test getting non-existent session yields 404."""
        fake_id = str(uuid.uuid4())
        response = client.get(f"/api/v1/vocab/sessions/{fake_id}")

        assert response.status_code == 404
        assert "not found" in response.json()["detail"].lower()

    def test_delete_session(self, client, sample_session_data):
        """Test deleting a session removes it permanently."""
        # Create session
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]

        # Delete session
        response = client.delete(f"/api/v1/vocab/sessions/{session_id}")

        assert response.status_code == 200
        assert "deleted" in response.json()["message"].lower()

        # Verify it's gone
        get_response = client.get(f"/api/v1/vocab/sessions/{session_id}")
        assert get_response.status_code == 404

    def test_delete_session_not_found(self, client):
        """Test deleting non-existent session yields 404."""
        fake_id = str(uuid.uuid4())
        response = client.delete(f"/api/v1/vocab/sessions/{fake_id}")

        assert response.status_code == 404
|
||||
|
||||
|
||||
# =============================================
|
||||
# VOCABULARY TESTS
|
||||
# =============================================
|
||||
|
||||
class TestVocabulary:
    """Test vocabulary operations."""

    def _create_session(self, client, payload):
        """Helper: create a session via the API and return its id."""
        created = client.post("/api/v1/vocab/sessions", json=payload)
        return created.json()["id"]

    def test_get_vocabulary_empty(self, client, sample_session_data):
        """A fresh session exposes an empty vocabulary list."""
        session_id = self._create_session(client, sample_session_data)

        response = client.get(f"/api/v1/vocab/sessions/{session_id}/vocabulary")

        assert response.status_code == 200
        payload = response.json()
        assert payload["session_id"] == session_id
        assert payload["vocabulary"] == []

    def test_update_vocabulary(self, client, sample_session_data, sample_vocabulary):
        """PUT replaces the vocabulary and updates the count."""
        session_id = self._create_session(client, sample_session_data)

        response = client.put(
            f"/api/v1/vocab/sessions/{session_id}/vocabulary",
            json={"vocabulary": sample_vocabulary}
        )

        assert response.status_code == 200
        assert response.json()["vocabulary_count"] == 3

        # Verify vocabulary was actually persisted.
        stored = client.get(f"/api/v1/vocab/sessions/{session_id}/vocabulary").json()
        assert len(stored["vocabulary"]) == 3

    def test_update_vocabulary_not_found(self, client, sample_vocabulary):
        """PUT against an unknown session id yields 404."""
        missing_id = str(uuid.uuid4())
        response = client.put(
            f"/api/v1/vocab/sessions/{missing_id}/vocabulary",
            json={"vocabulary": sample_vocabulary}
        )

        assert response.status_code == 404
|
||||
|
||||
|
||||
# =============================================
|
||||
# WORKSHEET GENERATION TESTS
|
||||
# =============================================
|
||||
|
||||
class TestWorksheetGeneration:
    """Test worksheet generation.

    PDF rendering is mocked (``generate_worksheet_pdf``) so only the API
    workflow and its validation logic are exercised, never ReportLab-style
    rendering.
    """

    def test_generate_worksheet_no_vocabulary(self, client, sample_session_data):
        """Test generating worksheet without vocabulary fails with 400."""
        # Create session (no vocabulary)
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]

        # Try to generate worksheet
        response = client.post(
            f"/api/v1/vocab/sessions/{session_id}/generate",
            json={
                "worksheet_types": ["en_to_de"],
                "include_solutions": True
            }
        )

        assert response.status_code == 400
        assert "no vocabulary" in response.json()["detail"].lower()

    @patch('vocab_worksheet_api.generate_worksheet_pdf')
    def test_generate_worksheet_success(
        self, mock_pdf, client, sample_session_data, sample_vocabulary
    ):
        """Test successful worksheet generation for two types at once."""
        # Mock PDF generation to return fake bytes
        mock_pdf.return_value = b"%PDF-1.4 fake pdf content"

        # Create session with vocabulary
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]

        # Add vocabulary
        client.put(
            f"/api/v1/vocab/sessions/{session_id}/vocabulary",
            json={"vocabulary": sample_vocabulary}
        )

        # Generate worksheet
        response = client.post(
            f"/api/v1/vocab/sessions/{session_id}/generate",
            json={
                "worksheet_types": ["en_to_de", "de_to_en"],
                "include_solutions": True,
                "line_height": "large"
            }
        )

        assert response.status_code == 200
        data = response.json()
        assert "id" in data
        assert data["session_id"] == session_id
        assert "en_to_de" in data["worksheet_types"]
        assert "de_to_en" in data["worksheet_types"]

    def test_generate_worksheet_all_types(self, client, sample_session_data, sample_vocabulary):
        """Test that all worksheet types are accepted."""
        # Create session with vocabulary
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]

        # Add vocabulary
        client.put(
            f"/api/v1/vocab/sessions/{session_id}/vocabulary",
            json={"vocabulary": sample_vocabulary}
        )

        # Test each worksheet type; patch per iteration so every request
        # hits a fresh mock.
        for wtype in ["en_to_de", "de_to_en", "copy", "gap_fill"]:
            with patch('vocab_worksheet_api.generate_worksheet_pdf') as mock_pdf:
                mock_pdf.return_value = b"%PDF-1.4 fake"
                response = client.post(
                    f"/api/v1/vocab/sessions/{session_id}/generate",
                    json={"worksheet_types": [wtype]}
                )
                assert response.status_code == 200, f"Failed for type: {wtype}"
|
||||
|
||||
|
||||
# =============================================
|
||||
# JSON PARSING TESTS
|
||||
# =============================================
|
||||
|
||||
class TestJSONParsing:
    """Test vocabulary JSON parsing from LLM responses.

    ``parse_vocabulary_json`` must tolerate prose surrounding the JSON,
    optional example fields, and incomplete entries; the string fixtures
    below cover each case.
    """

    def test_parse_valid_json(self):
        """Test parsing valid JSON response."""
        response = '''
        {
            "vocabulary": [
                {"english": "achieve", "german": "erreichen"},
                {"english": "improve", "german": "verbessern"}
            ]
        }
        '''
        result = parse_vocabulary_json(response)

        assert len(result) == 2
        assert result[0].english == "achieve"
        assert result[0].german == "erreichen"

    def test_parse_json_with_extra_text(self):
        """Test parsing JSON with surrounding text (typical LLM chatter)."""
        response = '''
        Here is the extracted vocabulary:

        {
            "vocabulary": [
                {"english": "success", "german": "Erfolg"}
            ]
        }

        I found 1 vocabulary entry.
        '''
        result = parse_vocabulary_json(response)

        assert len(result) == 1
        assert result[0].english == "success"

    def test_parse_json_with_examples(self):
        """Test parsing JSON with example sentences."""
        response = '''
        {
            "vocabulary": [
                {
                    "english": "achieve",
                    "german": "erreichen",
                    "example": "She achieved her goals."
                }
            ]
        }
        '''
        result = parse_vocabulary_json(response)

        assert len(result) == 1
        assert result[0].example_sentence == "She achieved her goals."

    def test_parse_empty_response(self):
        """Test parsing empty/invalid response yields an empty list."""
        result = parse_vocabulary_json("")
        assert result == []

        result = parse_vocabulary_json("no json here")
        assert result == []

    def test_parse_json_missing_fields(self):
        """Test that entries without required fields are skipped."""
        response = '''
        {
            "vocabulary": [
                {"english": "valid", "german": "gueltig"},
                {"english": ""},
                {"german": "nur deutsch"},
                {"english": "also valid", "german": "auch gueltig"}
            ]
        }
        '''
        result = parse_vocabulary_json(response)

        # Only entries with both english and german should be included
        assert len(result) == 2
        assert result[0].english == "valid"
        assert result[1].english == "also valid"
|
||||
|
||||
|
||||
# =============================================
|
||||
# FILE UPLOAD TESTS
|
||||
# =============================================
|
||||
|
||||
class TestFileUpload:
    """Test file upload functionality.

    The Vision-LLM extraction step is mocked so uploads stay local and
    deterministic (DSGVO: no external calls).
    """

    @patch('vocab_worksheet_api.extract_vocabulary_from_image')
    def test_upload_image(self, mock_extract, client, sample_session_data, sample_image_bytes):
        """Test uploading an image file."""
        # Mock extraction to return sample vocabulary.
        # NOTE(review): presumably the 3-tuple is
        # (entries, confidence, raw_text) — verify against the real
        # signature of extract_vocabulary_from_image.
        mock_extract.return_value = (
            [
                VocabularyEntry(
                    id=str(uuid.uuid4()),
                    english="test",
                    german="Test"
                )
            ],
            0.85,
            ""
        )

        # Create session
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]

        # Upload image
        files = {"file": ("test.png", io.BytesIO(sample_image_bytes), "image/png")}
        response = client.post(
            f"/api/v1/vocab/sessions/{session_id}/upload",
            files=files
        )

        assert response.status_code == 200
        data = response.json()
        assert data["session_id"] == session_id
        assert data["vocabulary_count"] == 1

    def test_upload_invalid_file_type(self, client, sample_session_data):
        """Test uploading invalid file type is rejected with 400."""
        # Create session
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]

        # Try to upload a text file
        files = {"file": ("test.txt", io.BytesIO(b"hello"), "text/plain")}
        response = client.post(
            f"/api/v1/vocab/sessions/{session_id}/upload",
            files=files
        )

        assert response.status_code == 400
        assert "supported" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
# =============================================
|
||||
# STATUS WORKFLOW TESTS
|
||||
# =============================================
|
||||
|
||||
class TestSessionStatus:
    """Test session status transitions.

    Flow exercised here: pending (new) -> extracted (after upload) ->
    completed (after worksheet generation).
    """

    def test_initial_status_pending(self, client, sample_session_data):
        """Test that new session has PENDING status."""
        response = client.post("/api/v1/vocab/sessions", json=sample_session_data)

        assert response.json()["status"] == "pending"

    @patch('vocab_worksheet_api.extract_vocabulary_from_image')
    def test_status_after_extraction(self, mock_extract, client, sample_session_data, sample_image_bytes):
        """Test that status becomes EXTRACTED after processing."""
        # Extraction yields no entries — status must still advance.
        mock_extract.return_value = ([], 0.0, "")

        # Create and upload
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]

        files = {"file": ("test.png", io.BytesIO(sample_image_bytes), "image/png")}
        client.post(f"/api/v1/vocab/sessions/{session_id}/upload", files=files)

        # Check status
        get_response = client.get(f"/api/v1/vocab/sessions/{session_id}")
        assert get_response.json()["status"] == "extracted"

    @patch('vocab_worksheet_api.generate_worksheet_pdf')
    def test_status_after_generation(self, mock_pdf, client, sample_session_data, sample_vocabulary):
        """Test that status becomes COMPLETED after worksheet generation."""
        mock_pdf.return_value = b"%PDF"

        # Create session with vocabulary
        create_response = client.post("/api/v1/vocab/sessions", json=sample_session_data)
        session_id = create_response.json()["id"]

        # Add vocabulary
        client.put(
            f"/api/v1/vocab/sessions/{session_id}/vocabulary",
            json={"vocabulary": sample_vocabulary}
        )

        # Generate worksheet
        client.post(
            f"/api/v1/vocab/sessions/{session_id}/generate",
            json={"worksheet_types": ["en_to_de"]}
        )

        # Check status
        get_response = client.get(f"/api/v1/vocab/sessions/{session_id}")
        assert get_response.json()["status"] == "completed"
|
||||
|
||||
|
||||
# =============================================
|
||||
# EDGE CASES
|
||||
# =============================================
|
||||
|
||||
class TestEdgeCases:
    """Test edge cases and error handling."""

    def test_session_with_special_characters(self, client):
        """Umlauts and symbols in the session name survive the round trip."""
        response = client.post(
            "/api/v1/vocab/sessions",
            json={"name": "Englisch Klasse 7 - äöü ß € @"}
        )

        assert response.status_code == 200
        assert "äöü" in response.json()["name"]

    def test_vocabulary_with_long_entries(self, client, sample_session_data):
        """Very long vocabulary fields are accepted without error."""
        session_id = client.post(
            "/api/v1/vocab/sessions", json=sample_session_data
        ).json()["id"]

        oversized_entry = {
            "id": str(uuid.uuid4()),
            "english": "a" * 100,
            "german": "b" * 200,
            "example_sentence": "c" * 500
        }

        response = client.put(
            f"/api/v1/vocab/sessions/{session_id}/vocabulary",
            json={"vocabulary": [oversized_entry]}
        )

        assert response.status_code == 200

    def test_sessions_limit(self, client, sample_session_data):
        """The ``limit`` query parameter caps the session listing."""
        for index in range(10):
            payload = dict(sample_session_data, name=f"Session {index+1}")
            client.post("/api/v1/vocab/sessions", json=payload)

        response = client.get("/api/v1/vocab/sessions?limit=5")

        assert response.status_code == 200
        assert len(response.json()) == 5
|
||||
|
||||
|
||||
# =============================================
|
||||
# RUN TESTS
|
||||
# =============================================
|
||||
|
||||
if __name__ == "__main__":
    # pytest.main() returns the exit status; propagate it so running this
    # file directly reports failure to the shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
||||
539
klausur-service/backend/tests/test_worksheet_editor.py
Normal file
539
klausur-service/backend/tests/test_worksheet_editor.py
Normal file
@@ -0,0 +1,539 @@
|
||||
"""
|
||||
Unit Tests for Worksheet Editor API
|
||||
|
||||
Tests cover:
|
||||
- Worksheet CRUD (create, read, list, delete)
|
||||
- AI Image generation (mocked)
|
||||
- PDF Export
|
||||
- Health check
|
||||
|
||||
DSGVO Note: All tests run locally without external API calls.
|
||||
|
||||
BACKLOG: Feature not yet integrated into main.py
|
||||
See: https://macmini:3002/infrastructure/tests
|
||||
"""
|
||||
|
||||
import pytest
import sys
from pathlib import Path

# Skip entire module if worksheet_editor_api is not available
pytest.importorskip("worksheet_editor_api", reason="worksheet_editor_api not yet integrated - Backlog item")

# Mark all tests in this module as expected failures (backlog item)
pytestmark = pytest.mark.xfail(
    reason="worksheet_editor_api not yet integrated into main.py - Backlog item",
    strict=False  # Don't fail if test unexpectedly passes
)
import json
import uuid
import io
import os
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient

# Make the backend directory importable regardless of pytest's working
# directory. The previous `sys.path.insert(0, '..')` was relative to the
# CWD, so `main` / `worksheet_editor_api` only resolved when pytest was
# launched from the tests directory (conftest.py uses the same anchor).
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from main import app
from worksheet_editor_api import (
    worksheets_db,
    AIImageStyle,
    WorksheetStatus,
    WORKSHEET_STORAGE_DIR,
)
|
||||
|
||||
|
||||
# =============================================
|
||||
# FIXTURES
|
||||
# =============================================
|
||||
|
||||
@pytest.fixture
def client():
    """FastAPI test client bound to the application under test."""
    test_client = TestClient(app)
    return test_client
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def clear_storage():
    """Reset the in-memory worksheet store and test files around each test."""
    worksheets_db.clear()
    # Remove any leftover files produced by earlier test runs.
    if os.path.exists(WORKSHEET_STORAGE_DIR):
        for entry in os.listdir(WORKSHEET_STORAGE_DIR):
            if entry.startswith(('test_', 'ws_test_')):
                os.remove(os.path.join(WORKSHEET_STORAGE_DIR, entry))
    yield
    worksheets_db.clear()
|
||||
|
||||
|
||||
@pytest.fixture
def sample_worksheet_data():
    """A minimal single-page worksheet payload (one text + one rect element)."""
    canvas = {
        "objects": [
            {
                "type": "text",
                "text": "Überschrift",
                "left": 100,
                "top": 50,
                "fontSize": 24,
            },
            {
                "type": "rect",
                "left": 100,
                "top": 100,
                "width": 200,
                "height": 100,
            },
        ]
    }
    return {
        "title": "Test Arbeitsblatt",
        "description": "Testbeschreibung",
        "pages": [
            {"id": "page_1", "index": 0, "canvasJSON": json.dumps(canvas)}
        ],
        # A4 portrait with 15 mm margins.
        "pageFormat": {
            "width": 210,
            "height": 297,
            "orientation": "portrait",
            "margins": {"top": 15, "right": 15, "bottom": 15, "left": 15},
        },
    }
|
||||
|
||||
|
||||
@pytest.fixture
def sample_multipage_worksheet():
    """A three-page worksheet payload with empty canvases."""
    pages = [
        {"id": f"page_{n + 1}", "index": n, "canvasJSON": "{}"}
        for n in range(3)
    ]
    return {"title": "Mehrseitiges Arbeitsblatt", "pages": pages}
|
||||
|
||||
|
||||
# =============================================
|
||||
# WORKSHEET CRUD TESTS
|
||||
# =============================================
|
||||
|
||||
class TestWorksheetCRUD:
    """CRUD behaviour of the worksheet save/get/list/delete endpoints."""

    def test_create_worksheet(self, client, sample_worksheet_data):
        """A new worksheet gets a ws_-prefixed id, timestamps, and echoes its fields."""
        resp = client.post("/api/v1/worksheet/save", json=sample_worksheet_data)

        assert resp.status_code == 200
        body = resp.json()
        assert "id" in body
        assert body["id"].startswith("ws_")
        assert body["title"] == sample_worksheet_data["title"]
        assert body["description"] == sample_worksheet_data["description"]
        assert len(body["pages"]) == 1
        assert "createdAt" in body
        assert "updatedAt" in body

    def test_update_worksheet(self, client, sample_worksheet_data):
        """Saving with an existing id updates that worksheet in place."""
        ws_id = client.post(
            "/api/v1/worksheet/save", json=sample_worksheet_data
        ).json()["id"]

        # Re-save under the same id with a changed title.
        sample_worksheet_data["id"] = ws_id
        sample_worksheet_data["title"] = "Aktualisiertes Arbeitsblatt"
        resp = client.post("/api/v1/worksheet/save", json=sample_worksheet_data)

        assert resp.status_code == 200
        body = resp.json()
        assert body["id"] == ws_id
        assert body["title"] == "Aktualisiertes Arbeitsblatt"

    def test_get_worksheet(self, client, sample_worksheet_data):
        """A saved worksheet can be fetched by its id."""
        ws_id = client.post(
            "/api/v1/worksheet/save", json=sample_worksheet_data
        ).json()["id"]

        resp = client.get(f"/api/v1/worksheet/{ws_id}")

        assert resp.status_code == 200
        body = resp.json()
        assert body["id"] == ws_id
        assert body["title"] == sample_worksheet_data["title"]

    def test_get_nonexistent_worksheet(self, client):
        """Fetching an unknown id yields 404 with a 'not found' detail."""
        resp = client.get("/api/v1/worksheet/ws_nonexistent123")

        assert resp.status_code == 404
        assert "not found" in resp.json()["detail"].lower()

    def test_list_worksheets(self, client, sample_worksheet_data):
        """Listing returns the saved worksheets plus a total count."""
        for n in range(3):
            payload = dict(sample_worksheet_data, title=f"Arbeitsblatt {n + 1}")
            client.post("/api/v1/worksheet/save", json=payload)

        resp = client.get("/api/v1/worksheet/list/all")

        assert resp.status_code == 200
        body = resp.json()
        assert "worksheets" in body
        assert "total" in body
        assert body["total"] >= 3

    def test_delete_worksheet(self, client, sample_worksheet_data):
        """Deleting removes the worksheet; a subsequent fetch 404s."""
        ws_id = client.post(
            "/api/v1/worksheet/save", json=sample_worksheet_data
        ).json()["id"]

        resp = client.delete(f"/api/v1/worksheet/{ws_id}")

        assert resp.status_code == 200
        assert resp.json()["status"] == "deleted"

        # The worksheet must no longer be retrievable.
        assert client.get(f"/api/v1/worksheet/{ws_id}").status_code == 404

    def test_delete_nonexistent_worksheet(self, client):
        """Deleting an unknown id yields 404."""
        resp = client.delete("/api/v1/worksheet/ws_nonexistent123")

        assert resp.status_code == 404
|
||||
|
||||
|
||||
# =============================================
|
||||
# MULTIPAGE TESTS
|
||||
# =============================================
|
||||
|
||||
class TestMultipageWorksheet:
    """Multi-page worksheet handling."""

    def test_create_multipage_worksheet(self, client, sample_multipage_worksheet):
        """All pages of a multi-page worksheet are persisted."""
        resp = client.post("/api/v1/worksheet/save", json=sample_multipage_worksheet)

        assert resp.status_code == 200
        assert len(resp.json()["pages"]) == 3

    def test_page_indices(self, client, sample_multipage_worksheet):
        """Page order (the index field) survives a save round-trip."""
        resp = client.post("/api/v1/worksheet/save", json=sample_multipage_worksheet)

        for position, page in enumerate(resp.json()["pages"]):
            assert page["index"] == position
|
||||
|
||||
|
||||
# =============================================
|
||||
# AI IMAGE TESTS
|
||||
# =============================================
|
||||
|
||||
class TestAIImageGeneration:
    """AI image generation endpoint (serves a placeholder when Ollama is down)."""

    def test_ai_image_with_valid_prompt(self, client):
        """A valid prompt yields a base64 PNG data URI plus the prompt used."""
        resp = client.post(
            "/api/v1/worksheet/ai-image",
            json={
                "prompt": "A friendly cartoon dog reading a book",
                "style": "cartoon",
                "width": 256,
                "height": 256,
            },
        )

        # 200 even without Ollama: the endpoint falls back to a placeholder.
        assert resp.status_code == 200
        body = resp.json()
        assert "image_base64" in body
        assert body["image_base64"].startswith("data:image/png;base64,")
        assert "prompt_used" in body

    def test_ai_image_styles(self, client):
        """Every supported style value is accepted and reflected in the prompt."""
        for style in ("realistic", "cartoon", "sketch", "clipart", "educational"):
            resp = client.post(
                "/api/v1/worksheet/ai-image",
                json={
                    "prompt": "Test image",
                    "style": style,
                    "width": 256,
                    "height": 256,
                },
            )

            assert resp.status_code == 200
            prompt_used = resp.json()["prompt_used"].lower()
            assert style in prompt_used or "test image" in prompt_used

    def test_ai_image_empty_prompt(self, client):
        """An empty prompt is rejected by request validation (422)."""
        resp = client.post(
            "/api/v1/worksheet/ai-image",
            json={
                "prompt": "",
                "style": "educational",
                "width": 256,
                "height": 256,
            },
        )

        assert resp.status_code == 422  # Validation error

    def test_ai_image_invalid_dimensions(self, client):
        """Out-of-range dimensions are rejected by request validation (422)."""
        resp = client.post(
            "/api/v1/worksheet/ai-image",
            json={
                "prompt": "Test",
                "style": "educational",
                "width": 50,  # Too small
                "height": 256,
            },
        )

        assert resp.status_code == 422  # Validation error
|
||||
|
||||
|
||||
# =============================================
|
||||
# CANVAS JSON TESTS
|
||||
# =============================================
|
||||
|
||||
class TestCanvasJSON:
    """Canvas JSON round-tripping."""

    def test_save_and_load_canvas_json(self, client):
        """Canvas objects survive a save/load round-trip unchanged."""
        canvas = {
            "objects": [
                {"type": "text", "text": "Hello", "left": 10, "top": 10},
                {"type": "rect", "left": 50, "top": 50, "width": 100, "height": 50},
            ],
            "background": "#ffffff",
        }
        payload = {
            "title": "Canvas Test",
            "pages": [
                {"id": "page_1", "index": 0, "canvasJSON": json.dumps(canvas)}
            ],
        }

        # Save, then fetch it back.
        ws_id = client.post("/api/v1/worksheet/save", json=payload).json()["id"]
        loaded = client.get(f"/api/v1/worksheet/{ws_id}").json()

        # The parsed canvas must match what was stored.
        round_tripped = json.loads(loaded["pages"][0]["canvasJSON"])
        assert round_tripped["objects"] == canvas["objects"]

    def test_empty_canvas_json(self, client):
        """A page with an empty canvas ('{}') is accepted."""
        payload = {
            "title": "Empty Canvas",
            "pages": [
                {"id": "page_1", "index": 0, "canvasJSON": "{}"}
            ],
        }

        resp = client.post("/api/v1/worksheet/save", json=payload)

        assert resp.status_code == 200
|
||||
|
||||
|
||||
# =============================================
|
||||
# PAGE FORMAT TESTS
|
||||
# =============================================
|
||||
|
||||
class TestPageFormat:
    """Page format defaults and overrides."""

    def test_default_page_format(self, client):
        """Omitting pageFormat falls back to A4 portrait (210 x 297 mm)."""
        payload = {
            "title": "Default Format",
            "pages": [{"id": "p1", "index": 0, "canvasJSON": "{}"}],
        }

        body = client.post("/api/v1/worksheet/save", json=payload).json()

        assert body["pageFormat"]["width"] == 210
        assert body["pageFormat"]["height"] == 297
        assert body["pageFormat"]["orientation"] == "portrait"

    def test_custom_page_format(self, client):
        """A caller-supplied pageFormat is stored as given."""
        payload = {
            "title": "Custom Format",
            "pages": [{"id": "p1", "index": 0, "canvasJSON": "{}"}],
            # A4 landscape with 20 mm margins.
            "pageFormat": {
                "width": 297,
                "height": 210,
                "orientation": "landscape",
                "margins": {"top": 20, "right": 20, "bottom": 20, "left": 20},
            },
        }

        body = client.post("/api/v1/worksheet/save", json=payload).json()

        assert body["pageFormat"]["width"] == 297
        assert body["pageFormat"]["height"] == 210
        assert body["pageFormat"]["orientation"] == "landscape"
|
||||
|
||||
|
||||
# =============================================
|
||||
# HEALTH CHECK TESTS
|
||||
# =============================================
|
||||
|
||||
class TestHealthCheck:
    """Health check endpoint."""

    def test_health_check(self, client):
        """The endpoint reports 'healthy' plus storage/reportlab diagnostics."""
        resp = client.get("/api/v1/worksheet/health/check")

        assert resp.status_code == 200
        body = resp.json()
        assert "status" in body
        assert body["status"] == "healthy"
        assert "storage" in body
        assert "reportlab" in body
|
||||
|
||||
|
||||
# =============================================
|
||||
# PDF EXPORT TESTS
|
||||
# =============================================
|
||||
|
||||
class TestPDFExport:
    """PDF export endpoint."""

    def test_export_pdf(self, client, sample_worksheet_data):
        """Exporting a saved worksheet returns a PDF download (or 501 without reportlab)."""
        ws_id = client.post(
            "/api/v1/worksheet/save", json=sample_worksheet_data
        ).json()["id"]

        resp = client.post(f"/api/v1/worksheet/{ws_id}/export-pdf")

        # 501 is acceptable when reportlab is not installed.
        assert resp.status_code in [200, 501]

        if resp.status_code == 200:
            assert resp.headers["content-type"] == "application/pdf"
            assert "attachment" in resp.headers.get("content-disposition", "")

    def test_export_nonexistent_worksheet_pdf(self, client):
        """Exporting an unknown worksheet yields 404 (or 501 without reportlab)."""
        resp = client.post("/api/v1/worksheet/ws_nonexistent/export-pdf")

        assert resp.status_code in [404, 501]
|
||||
|
||||
|
||||
# =============================================
|
||||
# ERROR HANDLING TESTS
|
||||
# =============================================
|
||||
|
||||
class TestErrorHandling:
    """Request validation and error handling."""

    def test_invalid_json(self, client):
        """A syntactically invalid JSON body is rejected with 422."""
        resp = client.post(
            "/api/v1/worksheet/save",
            content="not valid json",
            headers={"Content-Type": "application/json"},
        )

        assert resp.status_code == 422

    def test_missing_required_fields(self, client):
        """Omitting the required title field is rejected with 422."""
        resp = client.post(
            "/api/v1/worksheet/save",
            json={"pages": []},  # Missing title
        )

        assert resp.status_code == 422

    def test_empty_pages_array(self, client):
        """A worksheet without any pages is allowed."""
        resp = client.post(
            "/api/v1/worksheet/save",
            json={"title": "No Pages", "pages": []},
        )

        # Empty worksheets are permitted by the API.
        assert resp.status_code == 200
|
||||
Reference in New Issue
Block a user