"""Stage 1: Audio transcription using faster-whisper with GPU acceleration."""

import json
from dataclasses import dataclass
from pathlib import Path

from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn

console = Console()


@dataclass
class TranscriptWord:
    """A single recognized word with its timing and confidence.

    Populated only by the sequential (non-batched) path of :func:`transcribe`.
    """

    word: str
    start: float  # start time as reported by faster-whisper (seconds)
    end: float  # end time as reported by faster-whisper (seconds)
    probability: float  # word probability as reported by faster-whisper


@dataclass
class TranscriptSegment:
    """A contiguous span of recognized speech."""

    id: int  # 0-based position of the segment within the transcript
    text: str  # raw segment text; callers strip it before joining/exporting
    start: float  # segment start, in seconds (formatted as SRT time on export)
    end: float  # segment end, in seconds
    words: list[TranscriptWord]  # empty when transcribed in batched mode


@dataclass
class Transcript:
    """Full transcription result plus text/JSON/SRT export helpers."""

    segments: list[TranscriptSegment]
    language: str
    language_probability: float
    duration: float  # total audio duration in seconds

    @property
    def full_text(self) -> str:
        """Return every segment's stripped text joined by single spaces."""
        return " ".join(seg.text.strip() for seg in self.segments)

    def text_at(self, start: float, end: float) -> str:
        """Get transcript text within a time range.

        Every segment overlapping ``[start, end]`` is included; a segment
        that merely touches a boundary counts as overlapping. The scan stops
        at the first segment beginning after ``end``, so ``segments`` must be
        ordered by start time.
        """
        result = []
        for seg in self.segments:
            if seg.end < start:
                continue
            if seg.start > end:
                break  # relies on chronological ordering of segments
            result.append(seg.text.strip())
        return " ".join(result)

    def to_srt(self) -> str:
        """Export as SRT subtitle format (1-based cue numbers, blank-line separated)."""
        lines = []
        for i, seg in enumerate(self.segments, 1):
            start = _format_srt_time(seg.start)
            end = _format_srt_time(seg.end)
            lines.append(f"{i}")
            lines.append(f"{start} --> {end}")
            lines.append(seg.text.strip())
            lines.append("")
        return "\n".join(lines)

    def to_dict(self) -> dict:
        """Return a JSON-serializable dict with all segments and words."""
        return {
            "language": self.language,
            "language_probability": self.language_probability,
            "duration": self.duration,
            "segments": [
                {
                    "id": seg.id,
                    "text": seg.text,
                    "start": seg.start,
                    "end": seg.end,
                    "words": [
                        {
                            "word": w.word,
                            "start": w.start,
                            "end": w.end,
                            "probability": w.probability,
                        }
                        for w in seg.words
                    ],
                }
                for seg in self.segments
            ],
        }

    def save(self, output_dir: Path) -> None:
        """Write transcript.json, transcript.txt and transcript.srt into *output_dir*.

        The directory (and any missing parents) is created if needed; existing
        files with those names are overwritten.
        """
        output_dir.mkdir(parents=True, exist_ok=True)

        # JSON with full detail. ensure_ascii=False keeps non-ASCII text
        # readable; the file is opened as UTF-8 like the other outputs.
        with open(output_dir / "transcript.json", "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

        # Plain text
        with open(output_dir / "transcript.txt", "w", encoding="utf-8") as f:
            f.write(self.full_text)

        # SRT subtitles
        with open(output_dir / "transcript.srt", "w", encoding="utf-8") as f:
            f.write(self.to_srt())

        console.print(f"[green]Transcript saved to {output_dir}[/green]")


def _format_srt_time(seconds: float) -> str:
    """Format *seconds* as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to the nearest millisecond once, then split with
    integer arithmetic. This avoids float truncation errors (e.g. 1.001 s
    rendering as ``,000``) and guarantees each field stays in range.
    Negative inputs are clamped to zero.
    """
    total_ms = max(0, round(seconds * 1000))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, millis = divmod(rem, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def transcribe(audio_path: str | Path, model_size: str = "large-v3", language: str = "en",
               device: str = "cuda", batch_size: int = 16) -> Transcript:
    """Transcribe an audio file using faster-whisper.

    Uses BatchedInferencePipeline + int8_float16 + VAD for archive/batch work.
    Word timestamps are skipped in batch mode (not needed for segment-level search).
    Pass batch_size=0 to fall back to sequential WhisperModel with word timestamps.

    Args:
        audio_path: Path to the audio file.
        model_size: faster-whisper model name.
        language: Language code passed to the model (no auto-detection).
        device: Device passed to ``WhisperModel``.
        batch_size: Batch size for the batched pipeline; ``<= 0`` selects the
            sequential float16 path with word timestamps and explicit VAD settings.

    Returns:
        A :class:`Transcript` with one :class:`TranscriptSegment` per model segment.
    """
    # Imported lazily so the module can be imported without faster-whisper.
    from faster_whisper import WhisperModel, BatchedInferencePipeline

    audio_path = Path(audio_path)
    use_batched = batch_size > 0

    console.print(f"[bold]Transcribing:[/bold] {audio_path.name}")
    console.print(
        f"[dim]Model: {model_size} | "
        f"{'batched x' + str(batch_size) + ' int8_float16' if use_batched else 'sequential float16'} | "
        f"Device: {device}[/dim]"
    )

    if use_batched:
        base_model = WhisperModel(model_size, device=device, compute_type="int8_float16")
        model = BatchedInferencePipeline(model=base_model)
        segments_raw, info = model.transcribe(
            str(audio_path),
            language=language,
            batch_size=batch_size,
        )
    else:
        model = WhisperModel(model_size, device=device, compute_type="float16")
        segments_raw, info = model.transcribe(
            str(audio_path),
            language=language,
            word_timestamps=True,
            vad_filter=True,
            vad_parameters=dict(min_silence_duration_ms=500, speech_pad_ms=200),
        )

    console.print(f"[dim]Duration: {info.duration:.1f}s ({info.duration / 60:.1f} min)[/dim]")

    # segments_raw is consumed incrementally here; progress is logged every 50 segments.
    segments = []
    for i, seg in enumerate(segments_raw):
        words = []
        if not use_batched:
            words = [
                TranscriptWord(word=w.word, start=w.start, end=w.end, probability=w.probability)
                for w in (seg.words or [])
            ]
        segments.append(TranscriptSegment(
            id=i,
            text=seg.text,
            start=seg.start,
            end=seg.end,
            words=words,
        ))
        if i % 50 == 0:
            console.print(f"[dim]  {i} segments... ({seg.end:.0f}s)[/dim]")

    console.print(f"[green]Transcription complete: {len(segments)} segments[/green]")

    return Transcript(
        segments=segments,
        language=info.language,
        language_probability=info.language_probability,
        duration=info.duration,
    )