""" Tests for Transcription Worker Components Tests for the transcription pipeline including: - Whisper transcription - Speaker diarization - Segment alignment - Export formats (VTT, SRT, JSON) """ import pytest from unittest.mock import Mock, patch, MagicMock from datetime import datetime import json class TestTranscriberModule: """Tests for the Whisper transcription module.""" def test_transcription_result_structure(self): """Test that transcription results have the expected structure.""" # Mock transcription result result = { "text": "Dies ist ein Test.", "segments": [ { "id": 0, "start": 0.0, "end": 2.5, "text": " Dies ist ein Test.", "avg_logprob": -0.25, "no_speech_prob": 0.01 } ], "language": "de" } assert "text" in result assert "segments" in result assert len(result["segments"]) > 0 assert "start" in result["segments"][0] assert "end" in result["segments"][0] assert "text" in result["segments"][0] def test_confidence_calculation(self): """Test confidence score calculation from log probabilities.""" # avg_logprob of -0.25 should give ~78% confidence import math avg_logprob = -0.25 confidence = math.exp(avg_logprob) assert 0.7 < confidence < 0.8 def test_segment_timing_validation(self): """Test that segment timings are valid.""" segments = [ {"start": 0.0, "end": 2.5, "text": "First segment"}, {"start": 2.5, "end": 5.0, "text": "Second segment"}, {"start": 5.0, "end": 7.5, "text": "Third segment"} ] for i, segment in enumerate(segments): # End must be after start assert segment["end"] > segment["start"] # No negative times assert segment["start"] >= 0 assert segment["end"] >= 0 # Segments should be sequential if i > 0: assert segment["start"] >= segments[i - 1]["end"] class TestDiarizationModule: """Tests for speaker diarization module.""" def test_speaker_segment_structure(self): """Test speaker segment structure.""" speaker_segments = [ {"speaker": "SPEAKER_00", "start": 0.0, "end": 3.0}, {"speaker": "SPEAKER_01", "start": 3.0, "end": 6.0}, {"speaker": "SPEAKER_00", "start": 6.0, "end": 9.0} ] for segment in speaker_segments: assert "speaker" in segment assert "start" in segment assert "end" in segment assert segment["speaker"].startswith("SPEAKER_") def test_multiple_speakers_detection(self): """Test that multiple speakers are detected.""" speaker_segments = [ {"speaker": "SPEAKER_00", "start": 0.0, "end": 3.0}, {"speaker": "SPEAKER_01", "start": 3.0, "end": 6.0}, {"speaker": "SPEAKER_02", "start": 6.0, "end": 9.0} ] unique_speakers = set(s["speaker"] for s in speaker_segments) assert len(unique_speakers) == 3 def test_overlapping_speech_handling(self): """Test handling of overlapping speech.""" # In real diarization, overlapping speech is split segments = [ {"speaker": "SPEAKER_00", "start": 0.0, "end": 5.0}, {"speaker": "SPEAKER_01", "start": 4.0, "end": 8.0} # Overlap at 4-5 ] # Detect overlap overlap_detected = False for i in range(len(segments) - 1): if segments[i]["end"] > segments[i + 1]["start"]: overlap_detected = True break assert overlap_detected class TestAlignmentModule: """Tests for text-speaker alignment.""" def test_align_transcription_with_speakers(self): """Test aligning transcription segments with speaker segments.""" transcription_segments = [ {"start": 0.0, "end": 2.0, "text": "Guten Tag."}, {"start": 2.0, "end": 4.0, "text": "Wie geht es Ihnen?"}, {"start": 4.0, "end": 6.0, "text": "Mir geht es gut, danke."} ] speaker_segments = [ {"speaker": "SPEAKER_00", "start": 0.0, "end": 2.5}, {"speaker": "SPEAKER_01", "start": 2.5, "end": 6.0} ] # Simple alignment by overlap def find_speaker(text_start, text_end, speakers): max_overlap = 0 best_speaker = None for sp in speakers: overlap_start = max(text_start, sp["start"]) overlap_end = min(text_end, sp["end"]) overlap = max(0, overlap_end - overlap_start) if overlap > max_overlap: max_overlap = overlap best_speaker = sp["speaker"] return best_speaker aligned = [] for seg in transcription_segments: speaker = find_speaker(seg["start"], seg["end"], speaker_segments) aligned.append({**seg, "speaker": speaker}) assert aligned[0]["speaker"] == "SPEAKER_00" assert aligned[1]["speaker"] == "SPEAKER_01" assert aligned[2]["speaker"] == "SPEAKER_01" class TestVTTExport: """Tests for WebVTT export format.""" def test_vtt_header(self): """Test VTT file starts with correct header.""" vtt_content = "WEBVTT\n\n00:00:00.000 --> 00:00:02.000\nTest" assert vtt_content.startswith("WEBVTT") def test_vtt_timestamp_format(self): """Test VTT timestamp format is correct.""" def format_vtt_time(ms): hours = ms // 3600000 minutes = (ms % 3600000) // 60000 seconds = (ms % 60000) // 1000 millis = ms % 1000 return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{millis:03d}" assert format_vtt_time(0) == "00:00:00.000" assert format_vtt_time(1500) == "00:00:01.500" assert format_vtt_time(3661500) == "01:01:01.500" def test_vtt_cue_format(self): """Test VTT cue block format.""" def create_vtt_cue(start_ms, end_ms, text): start = f"{start_ms // 3600000:02d}:{(start_ms % 3600000) // 60000:02d}:{(start_ms % 60000) // 1000:02d}.{start_ms % 1000:03d}" end = f"{end_ms // 3600000:02d}:{(end_ms % 3600000) // 60000:02d}:{(end_ms % 60000) // 1000:02d}.{end_ms % 1000:03d}" return f"{start} --> {end}\n{text}" cue = create_vtt_cue(0, 2500, "Test subtitle") assert "-->" in cue assert "Test subtitle" in cue class TestSRTExport: """Tests for SRT subtitle export format.""" def test_srt_timestamp_format(self): """Test SRT timestamp format (uses comma instead of period).""" def format_srt_time(ms): hours = ms // 3600000 minutes = (ms % 3600000) // 60000 seconds = (ms % 60000) // 1000 millis = ms % 1000 return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}" assert format_srt_time(0) == "00:00:00,000" assert format_srt_time(1500) == "00:00:01,500" def test_srt_entry_format(self): """Test SRT entry format with index.""" def create_srt_entry(index, start_ms, end_ms, text): def fmt(ms): return f"{ms // 3600000:02d}:{(ms % 3600000) // 60000:02d}:{(ms % 60000) // 1000:02d},{ms % 1000:03d}" return f"{index}\n{fmt(start_ms)} --> {fmt(end_ms)}\n{text}\n" entry = create_srt_entry(1, 0, 2500, "Test") lines = entry.strip().split("\n") assert lines[0] == "1" assert "-->" in lines[1] assert lines[2] == "Test" class TestJSONExport: """Tests for JSON export format.""" def test_json_structure(self): """Test JSON export structure.""" export = { "transcription_id": "abc123", "recording_id": "xyz789", "language": "de", "model": "large-v3", "created_at": datetime.utcnow().isoformat(), "duration_seconds": 300, "word_count": 500, "confidence_score": 0.92, "segments": [ { "id": 0, "start_ms": 0, "end_ms": 2500, "text": "Test segment", "speaker": "SPEAKER_00", "confidence": 0.95 } ] } # Verify structure assert "transcription_id" in export assert "segments" in export assert len(export["segments"]) > 0 # Verify serializable json_str = json.dumps(export) assert len(json_str) > 0 class TestMinIOStorage: """Tests for MinIO storage operations.""" def test_recording_path_format(self): """Test recording storage path format.""" recording_name = "test-meeting_20260115_120000" base_path = f"recordings/{recording_name}" video_path = f"{base_path}/video.mp4" audio_path = f"{base_path}/audio.wav" vtt_path = f"{base_path}/transcript.vtt" assert video_path == "recordings/test-meeting_20260115_120000/video.mp4" assert audio_path == "recordings/test-meeting_20260115_120000/audio.wav" assert vtt_path == "recordings/test-meeting_20260115_120000/transcript.vtt" def test_bucket_name_validation(self): """Test MinIO bucket name is valid.""" bucket = "breakpilot-recordings" # MinIO bucket name rules assert len(bucket) >= 3 assert len(bucket) <= 63 assert bucket.islower() or "-" in bucket assert not bucket.startswith("-") assert not bucket.endswith("-") class TestQueueProcessing: """Tests for RQ queue processing.""" def test_job_payload_structure(self): """Test transcription job payload structure.""" job_payload = { "transcription_id": "abc123", "recording_id": "xyz789", "audio_path": "recordings/test/audio.wav", "language": "de", "model": "large-v3", "priority": 0 } required_fields = ["transcription_id", "recording_id", "audio_path", "language", "model"] for field in required_fields: assert field in job_payload def test_job_status_transitions(self): """Test valid job status transitions.""" valid_transitions = { "pending": ["queued", "cancelled"], "queued": ["processing", "cancelled"], "processing": ["completed", "failed"], "completed": [], "failed": ["pending"], # Can retry "cancelled": [] } # Test a valid workflow status = "pending" assert "queued" in valid_transitions[status] status = "queued" assert "processing" in valid_transitions[status] status = "processing" assert "completed" in valid_transitions[status]