"""Stage 2: Speaker diarization with voice-profile matching.

Speaker identification is delegated to the project's ``VoiceProfiler``
(WavLM sliding-window matching, see :func:`diarize`). :class:`VoiceProfileStore`
manages the on-disk ``.npy`` voice embeddings.
"""

import json
from dataclasses import dataclass
from pathlib import Path

import numpy as np
from rich.console import Console

console = Console()

# Confidence assigned to a turn whose profiler label carries no parsable
# "(score)" suffix.
_DEFAULT_CONFIDENCE = 0.5


@dataclass
class SpeakerTurn:
    """A contiguous span of audio attributed to one speaker."""

    speaker: str  # "SPEAKER_00", "Host: Mike Swanson", "Caller 1", etc.
    start: float
    end: float
    confidence: float = 1.0

    @property
    def duration(self) -> float:
        """Length of the turn (``end - start``)."""
        return self.end - self.start


@dataclass
class DiarizationResult:
    """Ordered speaker turns plus summary metadata for one audio file."""

    turns: list[SpeakerTurn]
    num_speakers: int
    speaker_map: dict[str, str]  # raw label -> friendly name

    def speaker_at(self, time: float) -> str | None:
        """Return the speaker of the first turn containing ``time`` (inclusive), else None."""
        for turn in self.turns:
            if turn.start <= time <= turn.end:
                return turn.speaker
        return None

    def speaker_time(self, speaker: str) -> float:
        """Total speaking time for ``speaker``."""
        return sum(t.duration for t in self.turns if t.speaker == speaker)

    def speakers_ranked(self) -> list[tuple[str, float]]:
        """Speakers ranked by total speaking time, longest first."""
        times: dict[str, float] = {}
        for turn in self.turns:
            times[turn.speaker] = times.get(turn.speaker, 0) + turn.duration
        return sorted(times.items(), key=lambda x: x[1], reverse=True)

    def to_dict(self) -> dict:
        """JSON-serialisable representation of the result."""
        return {
            "num_speakers": self.num_speakers,
            "speaker_map": self.speaker_map,
            "turns": [
                {
                    "speaker": t.speaker,
                    "start": t.start,
                    "end": t.end,
                    "confidence": t.confidence,
                }
                for t in self.turns
            ],
        }

    def save(self, output_dir: Path):
        """Write ``diarization.json`` into ``output_dir`` (created if missing)."""
        output_dir.mkdir(parents=True, exist_ok=True)
        with open(output_dir / "diarization.json", "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)
        console.print(f"[green]Diarization saved to {output_dir}[/green]")


class VoiceProfileStore:
    """Manages speaker voice embeddings for identification.

    Profiles are ``.npy`` files found recursively under ``profiles_dir``. The
    name of a file's parent directory determines its role:

    * ``host-<name>/`` -> role ``"host"``, display name taken from the directory
    * ``guests/``      -> role ``"guest"``, display name from the title-cased stem
    * ``callers/``     -> role ``"caller"``, display name is the file stem
    * anything else    -> role ``"unknown"``, display name is the file stem

    Profiles are keyed by file stem. Two files with the same stem in different
    directories therefore overwrite each other, and the last one yielded by
    ``rglob`` wins.
    """

    def __init__(self, profiles_dir: str | Path):
        self.profiles_dir = Path(profiles_dir)
        self.embeddings: dict[str, np.ndarray] = {}
        self.metadata: dict[str, dict] = {}
        self._load_profiles()

    @staticmethod
    def _describe(npy_file: Path) -> tuple[str, str]:
        """Derive ``(display name, role)`` for a profile file from its location."""
        stem = npy_file.stem
        parent = npy_file.parent.name
        if parent.startswith("host-"):
            # removeprefix: only strip the leading marker, not every "host-".
            return parent.removeprefix("host-").replace("-", " ").title(), "host"
        if parent == "guests":
            return stem.replace("-", " ").title(), "guest"
        if parent == "callers":
            return stem, "caller"
        return stem, "unknown"

    def _register(self, npy_file: Path, embedding: np.ndarray) -> None:
        """Add one embedding and its metadata to the in-memory store."""
        speaker_name, role = self._describe(npy_file)
        key = npy_file.stem
        self.embeddings[key] = embedding
        self.metadata[key] = {
            "name": speaker_name,
            "role": role,
            "file": str(npy_file),
        }

    def _load_profiles(self):
        """Load every ``*.npy`` profile under ``profiles_dir``; no-op if it is missing."""
        if not self.profiles_dir.exists():
            return
        for npy_file in self.profiles_dir.rglob("*.npy"):
            self._register(npy_file, np.load(npy_file))
        if self.embeddings:
            console.print(f"[dim]Loaded {len(self.embeddings)} voice profiles[/dim]")

    def match_embedding(self, embedding: np.ndarray, threshold: float = 0.75
                        ) -> tuple[str | None, float]:
        """Match an embedding against stored profiles by cosine similarity.

        Args:
            embedding: Query embedding, compared with ``np.dot`` against each
                stored profile (assumed to have the same shape — TODO confirm).
            threshold: Minimum similarity required to report a match.

        Returns:
            ``(name, similarity)``. ``name`` is ``"Host: <name>"`` for host
            profiles and the display name for other roles. It is ``None`` when
            no profiles are loaded or the best similarity is below
            ``threshold``. ``similarity`` is the best score seen, and is
            ``0.0`` if nothing scored above zero.
        """
        if not self.embeddings:
            return None, 0.0

        query_norm = np.linalg.norm(embedding)  # loop-invariant
        best_match = None
        best_score = 0.0
        for name, stored in self.embeddings.items():
            similarity = float(
                np.dot(embedding, stored) / (query_norm * np.linalg.norm(stored) + 1e-8)
            )
            if similarity > best_score:
                best_score = similarity
                best_match = name

        if best_match is None or best_score < threshold:
            return None, best_score

        meta = self.metadata.get(best_match, {})
        friendly_name = meta.get("name", best_match)
        if meta.get("role", "unknown") == "host":
            return f"Host: {friendly_name}", best_score
        return friendly_name, best_score

    def save_embedding(self, name: str, embedding: np.ndarray, role: str = "unknown"):
        """Save a new voice profile to disk and make it matchable immediately.

        The file goes into the directory that the loader maps back to ``role``
        (``host-<slug>/``, ``guests/``, ``callers/``; any other role goes to
        ``unknown/``). The profile is also registered in memory, so
        :meth:`match_embedding` can find it without re-creating the store.
        """
        slug = name.lower().replace(" ", "-")
        if role == "host":
            subdir = self.profiles_dir / f"host-{slug}"
        elif role == "guest":
            subdir = self.profiles_dir / "guests"
        elif role == "caller":
            subdir = self.profiles_dir / "callers"
        else:
            subdir = self.profiles_dir / "unknown"
        subdir.mkdir(parents=True, exist_ok=True)

        path = subdir / f"{slug}.npy"
        np.save(path, embedding)
        # Copy so later mutation of the caller's array can't alter the stored profile.
        self._register(path, np.array(embedding))
        console.print(f"[green]Saved voice profile: {name} ({role})[/green]")


def _load_bumper_ranges(transcript_path: str | Path | None) -> list[tuple[float, float]]:
    """Return ``(start, end)`` of transcript segments flagged as promo/bumper text.

    Returns an empty list when no transcript is given or the file does not
    exist. Segments are read from the transcript's ``"segments"`` list.
    """
    if transcript_path is None:
        return []
    transcript_path = Path(transcript_path)
    if not transcript_path.exists():
        return []

    from .qa_extractor import _is_promo_or_bumper

    with open(transcript_path, encoding="utf-8") as f:
        tdata = json.load(f)
    return [
        (seg["start"], seg["end"])
        for seg in tdata.get("segments", [])
        if _is_promo_or_bumper(seg.get("text", ""))
    ]


def _parse_voice_label(label: str) -> tuple[str, float]:
    """Map a VoiceProfiler label to ``(coarse speaker role, confidence)``.

    The label is expected to look like ``"<name> (<score>)"``, for example
    ``"Host: Mike Swanson (0.91)"`` (presumably — verify against
    VoiceProfiler). When the score suffix is missing or not numeric, the
    confidence defaults to ``_DEFAULT_CONFIDENCE``.
    """
    name = label.split(" (")[0]  # strip confidence score
    if name.startswith(("Host:", "Host ")):
        speaker = "HOST"
    elif name.startswith("Cohost:"):
        speaker = "CO-HOST"
    elif name == "[bumper]":
        speaker = "BUMPER"
    elif name == "[error]":
        speaker = "UNKNOWN"
    else:
        speaker = "CALLER"

    confidence = _DEFAULT_CONFIDENCE
    if "(" in label:
        try:
            confidence = float(label.split("(")[-1].rstrip(")"))
        except ValueError:
            # Non-numeric parenthetical: keep the default rather than abort the run.
            pass
    return speaker, confidence


def _merge_turns(turns: list[SpeakerTurn]) -> list[SpeakerTurn]:
    """Collapse runs of consecutive same-speaker turns into single turns.

    A merged turn keeps the first turn's start and confidence and takes the
    last turn's end. The input turns are not mutated.
    """
    merged: list[SpeakerTurn] = []
    for turn in turns:
        if merged and merged[-1].speaker == turn.speaker:
            merged[-1].end = turn.end
        else:
            merged.append(SpeakerTurn(
                speaker=turn.speaker,
                start=turn.start,
                end=turn.end,
                confidence=turn.confidence,
            ))
    return merged


def diarize(audio_path: str | Path,
            voice_profiles: VoiceProfileStore | None = None,
            min_speakers: int = 1,
            max_speakers: int = 6,
            host_match_threshold: float = 0.85,
            transcript_path: str | Path | None = None) -> DiarizationResult:
    """Run speaker diarization using WavLM sliding-window speaker identification.

    Uses the built-in VoiceProfiler (WavLM x-vectors), so no HuggingFace token
    or gated model is required. It separates HOST from non-HOST speakers using
    the stored host voice profile (e.g. Mike Swanson). Turns are labelled
    ``HOST``, ``CO-HOST``, ``CALLER``, ``BUMPER`` or ``UNKNOWN``, and
    consecutive same-speaker turns are merged.

    If ``transcript_path`` is provided, time ranges containing show
    promo/bumper text are pre-marked. They are skipped during speaker
    identification so that vocal music doesn't match cohost profiles.

    Args:
        audio_path: Audio file to diarize.
        voice_profiles: Store whose ``profiles_dir`` is handed to the
            VoiceProfiler. Defaults to ``voice-profiles`` when omitted.
        min_speakers: Currently unused; kept for interface compatibility.
        max_speakers: Currently unused; kept for interface compatibility.
        host_match_threshold: Similarity threshold passed to the profiler.
        transcript_path: Optional transcript JSON used for the bumper filter.

    Returns:
        A :class:`DiarizationResult`. If no voice profiles are found, it holds
        a single ``HOST`` turn covering the whole file.
    """
    import torch

    from .voice_profiler import VoiceProfiler

    audio_path = Path(audio_path)
    console.print(f"[bold]Diarizing:[/bold] {audio_path.name}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    console.print(f"[dim]Device: {device}[/dim]")

    # Locate voice profiles directory from the VoiceProfileStore path
    profiles_dir = (voice_profiles.profiles_dir if voice_profiles is not None
                    else Path("voice-profiles"))
    profiler = VoiceProfiler(profiles_dir, device=device)

    if not profiler.profiles:
        console.print("[yellow]No voice profiles found — labeling all as HOST[/yellow]")
        # Return a single HOST turn covering the whole episode
        duration = profiler._get_duration(audio_path)
        return DiarizationResult(
            turns=[SpeakerTurn(speaker="HOST", start=0.0, end=duration)],
            num_speakers=1,
            speaker_map={"HOST": "HOST"},
        )

    bumper_ranges = _load_bumper_ranges(transcript_path)
    if bumper_ranges:
        console.print(
            f"[dim]Bumper filter: {len(bumper_ranges)} promo/bumper "
            f"transcript segments will be skipped during speaker match[/dim]"
        )

    # Sliding-window identification: 10s windows, 5s hop
    voice_segs = profiler.identify_speakers(
        audio_path,
        window_s=10.0,
        hop_s=5.0,
        threshold=host_match_threshold,
        skip_ranges=bumper_ranges,
    )

    raw_turns = []
    for seg in voice_segs:
        speaker, confidence = _parse_voice_label(seg.speaker_label)
        raw_turns.append(SpeakerTurn(
            speaker=speaker,
            start=seg.start,
            end=seg.end,
            confidence=confidence,
        ))
    merged = _merge_turns(raw_turns)

    # Sorted so speaker_map (and the saved JSON) is deterministic across runs.
    unique_speakers = sorted({t.speaker for t in merged})
    result = DiarizationResult(
        turns=merged,
        num_speakers=len(unique_speakers),
        speaker_map={s: s for s in unique_speakers},
    )

    host_time = result.speaker_time("HOST")
    caller_time = result.speaker_time("CALLER")
    console.print(f"[green]Diarization complete:[/green] {len(merged)} turns | "
                  f"HOST {host_time:.0f}s / CALLER {caller_time:.0f}s")
    return result