""" BreakPilot Transcript Aligner Aligns Whisper transcription segments with pyannote speaker diarization. Assigns speaker IDs to each transcribed segment. """ import structlog from typing import List, Dict, Optional from collections import defaultdict log = structlog.get_logger(__name__) class TranscriptAligner: """ Aligns transcription segments with speaker diarization results. Uses overlap-based matching to assign speaker IDs to each transcribed segment. Handles cases where speakers change mid-sentence. """ def __init__(self): """Initialize the aligner.""" self._speaker_count = 0 self._speaker_map = {} # Maps pyannote IDs to friendly names def align( self, transcription_segments: List[Dict], diarization_segments: List[Dict], min_overlap_ratio: float = 0.3 ) -> List[Dict]: """ Align transcription with speaker diarization. Args: transcription_segments: List of segments from Whisper diarization_segments: List of segments from pyannote min_overlap_ratio: Minimum overlap ratio to assign speaker Returns: Transcription segments with speaker_id added """ if not diarization_segments: log.warning("no_diarization_segments", message="Returning transcription without speakers") return transcription_segments log.info( "aligning_transcription", transcription_count=len(transcription_segments), diarization_count=len(diarization_segments) ) # Build speaker mapping unique_speakers = set(s["speaker_id"] for s in diarization_segments) self._speaker_count = len(unique_speakers) for i, speaker in enumerate(sorted(unique_speakers)): self._speaker_map[speaker] = f"SPEAKER_{i:02d}" # Align each transcription segment aligned_segments = [] for trans_seg in transcription_segments: speaker_id = self._find_speaker_for_segment( trans_seg, diarization_segments, min_overlap_ratio ) aligned_seg = trans_seg.copy() aligned_seg["speaker_id"] = speaker_id aligned_segments.append(aligned_seg) # Log statistics speaker_counts = defaultdict(int) for seg in aligned_segments: speaker_counts[seg.get("speaker_id", "UNKNOWN")] += 1 log.info( "alignment_complete", speakers=dict(speaker_counts), total_speakers=self._speaker_count ) return aligned_segments def _find_speaker_for_segment( self, trans_seg: Dict, diarization_segments: List[Dict], min_overlap_ratio: float ) -> Optional[str]: """ Find the best matching speaker for a transcription segment. Uses overlap-based matching with the speaker who has the highest overlap with the segment. """ trans_start = trans_seg["start_time_ms"] trans_end = trans_seg["end_time_ms"] trans_duration = trans_end - trans_start if trans_duration <= 0: return None # Find overlapping diarization segments overlaps = [] for diar_seg in diarization_segments: diar_start = diar_seg["start_time_ms"] diar_end = diar_seg["end_time_ms"] # Calculate overlap overlap_start = max(trans_start, diar_start) overlap_end = min(trans_end, diar_end) overlap_duration = max(0, overlap_end - overlap_start) if overlap_duration > 0: overlap_ratio = overlap_duration / trans_duration overlaps.append({ "speaker_id": diar_seg["speaker_id"], "overlap_duration": overlap_duration, "overlap_ratio": overlap_ratio }) if not overlaps: return None # Find speaker with highest overlap best_match = max(overlaps, key=lambda x: x["overlap_duration"]) if best_match["overlap_ratio"] >= min_overlap_ratio: original_id = best_match["speaker_id"] return self._speaker_map.get(original_id, original_id) return None def get_speaker_count(self) -> int: """Get the number of unique speakers detected.""" return self._speaker_count def get_speaker_mapping(self) -> Dict[str, str]: """Get the mapping from pyannote IDs to friendly names.""" return self._speaker_map.copy() def merge_consecutive_segments( self, segments: List[Dict], max_gap_ms: int = 1000, same_speaker_only: bool = True ) -> List[Dict]: """ Merge consecutive segments that are close together. Useful for creating cleaner subtitle output. Args: segments: List of aligned segments max_gap_ms: Maximum gap between segments to merge same_speaker_only: Only merge if same speaker Returns: List of merged segments """ if not segments: return [] merged = [] current = segments[0].copy() for next_seg in segments[1:]: gap = next_seg["start_time_ms"] - current["end_time_ms"] same_speaker = ( not same_speaker_only or current.get("speaker_id") == next_seg.get("speaker_id") ) if gap <= max_gap_ms and same_speaker: # Merge segments current["end_time_ms"] = next_seg["end_time_ms"] current["text"] = current["text"] + " " + next_seg["text"] # Merge word timestamps if present if "words" in current and "words" in next_seg: current["words"].extend(next_seg["words"]) else: # Save current and start new merged.append(current) current = next_seg.copy() # Don't forget the last segment merged.append(current) log.info( "segments_merged", original_count=len(segments), merged_count=len(merged) ) return merged