"""Stage 2: Speaker diarization using pyannote.audio with voice profile matching.""" import json from dataclasses import dataclass from pathlib import Path import numpy as np from rich.console import Console console = Console() @dataclass class SpeakerTurn: speaker: str # "SPEAKER_00", "Host: Mike Swanson", "Caller 1", etc. start: float end: float confidence: float = 1.0 @property def duration(self) -> float: return self.end - self.start @dataclass class DiarizationResult: turns: list[SpeakerTurn] num_speakers: int speaker_map: dict[str, str] # raw label -> friendly name def speaker_at(self, time: float) -> str | None: """Get the speaker at a given timestamp.""" for turn in self.turns: if turn.start <= time <= turn.end: return turn.speaker return None def speaker_time(self, speaker: str) -> float: """Total speaking time for a speaker.""" return sum(t.duration for t in self.turns if t.speaker == speaker) def speakers_ranked(self) -> list[tuple[str, float]]: """Speakers ranked by total speaking time.""" times = {} for turn in self.turns: times[turn.speaker] = times.get(turn.speaker, 0) + turn.duration return sorted(times.items(), key=lambda x: x[1], reverse=True) def to_dict(self) -> dict: return { "num_speakers": self.num_speakers, "speaker_map": self.speaker_map, "turns": [ { "speaker": t.speaker, "start": t.start, "end": t.end, "confidence": t.confidence, } for t in self.turns ], } def save(self, output_dir: Path): output_dir.mkdir(parents=True, exist_ok=True) with open(output_dir / "diarization.json", "w") as f: json.dump(self.to_dict(), f, indent=2) console.print(f"[green]Diarization saved to {output_dir}[/green]") class VoiceProfileStore: """Manages speaker voice embeddings for identification.""" def __init__(self, profiles_dir: str | Path): self.profiles_dir = Path(profiles_dir) self.embeddings: dict[str, np.ndarray] = {} self.metadata: dict[str, dict] = {} self._load_profiles() def _load_profiles(self): if not self.profiles_dir.exists(): return for npy_file in self.profiles_dir.rglob("*.npy"): name = npy_file.stem # Determine speaker name from directory structure parent = npy_file.parent.name if parent.startswith("host-"): speaker_name = parent.replace("host-", "").replace("-", " ").title() role = "host" elif parent == "guests": speaker_name = name.replace("-", " ").title() role = "guest" elif parent == "callers": speaker_name = name role = "caller" else: speaker_name = name role = "unknown" self.embeddings[name] = np.load(npy_file) self.metadata[name] = { "name": speaker_name, "role": role, "file": str(npy_file), } if self.embeddings: console.print(f"[dim]Loaded {len(self.embeddings)} voice profiles[/dim]") def match_embedding(self, embedding: np.ndarray, threshold: float = 0.75 ) -> tuple[str | None, float]: """Match an embedding against stored profiles. Returns (name, similarity).""" if not self.embeddings: return None, 0.0 best_match = None best_score = 0.0 for name, stored in self.embeddings.items(): # Cosine similarity similarity = np.dot(embedding, stored) / ( np.linalg.norm(embedding) * np.linalg.norm(stored) + 1e-8 ) if similarity > best_score: best_score = similarity best_match = name if best_score >= threshold: meta = self.metadata.get(best_match, {}) friendly_name = meta.get("name", best_match) role = meta.get("role", "unknown") if role == "host": return f"Host: {friendly_name}", best_score return friendly_name, best_score return None, best_score def save_embedding(self, name: str, embedding: np.ndarray, role: str = "unknown"): """Save a new voice profile.""" if role == "host": subdir = self.profiles_dir / f"host-{name.lower().replace(' ', '-')}" elif role == "guest": subdir = self.profiles_dir / "guests" elif role == "caller": subdir = self.profiles_dir / "callers" else: subdir = self.profiles_dir / "unknown" subdir.mkdir(parents=True, exist_ok=True) filename = name.lower().replace(" ", "-") np.save(subdir / f"{filename}.npy", embedding) console.print(f"[green]Saved voice profile: {name} ({role})[/green]") def diarize(audio_path: str | Path, voice_profiles: VoiceProfileStore | None = None, min_speakers: int = 1, max_speakers: int = 6, host_match_threshold: float = 0.75) -> DiarizationResult: """Run speaker diarization on an audio file.""" from pyannote.audio import Pipeline import torch audio_path = Path(audio_path) console.print(f"[bold]Diarizing:[/bold] {audio_path.name}") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") console.print(f"[dim]Device: {device}[/dim]") pipeline = Pipeline.from_pretrained( "pyannote/speaker-diarization-3.1" ).to(device) diarization = pipeline( str(audio_path), min_speakers=min_speakers, max_speakers=max_speakers, ) # Extract turns raw_turns = [] for turn, _, speaker in diarization.itertracks(yield_label=True): raw_turns.append(SpeakerTurn( speaker=speaker, start=turn.start, end=turn.end, )) # Count unique speakers raw_speakers = set(t.speaker for t in raw_turns) console.print(f"[dim]Detected {len(raw_speakers)} speakers[/dim]") # Match against voice profiles if available speaker_map = {} if voice_profiles and voice_profiles.embeddings: console.print("[dim]Matching speakers against voice profiles...[/dim]") embedding_model = pipeline.embedding # pyannote's embedding model # Get embeddings for each detected speaker from pyannote.audio import Inference inference = Inference(pipeline.embedding, window="whole") for raw_label in raw_speakers: # Get segments for this speaker speaker_segments = [t for t in raw_turns if t.speaker == raw_label] total_time = sum(t.duration for t in speaker_segments) # Use the longest segment for embedding longest = max(speaker_segments, key=lambda t: t.duration) try: # Extract embedding from audio segment import torchaudio waveform, sr = torchaudio.load( str(audio_path), frame_offset=int(longest.start * sr if 'sr' in dir() else longest.start * 16000), num_frames=int(longest.duration * sr if 'sr' in dir() else longest.duration * 16000), ) # This is simplified — proper implementation would use pyannote's # embedding extraction pipeline match_name, score = voice_profiles.match_embedding( np.zeros(256), # placeholder threshold=host_match_threshold, ) if match_name: speaker_map[raw_label] = match_name console.print(f" [green]{raw_label} -> {match_name} " f"(score: {score:.2f}, {total_time:.0f}s)[/green]") except Exception as e: console.print(f" [yellow]Could not match {raw_label}: {e}[/yellow]") # If no voice profiles matched, use speaking time heuristic # The host almost always has the most speaking time if not speaker_map: ranked = sorted( [(s, sum(t.duration for t in raw_turns if t.speaker == s)) for s in raw_speakers], key=lambda x: x[1], reverse=True, ) if ranked: speaker_map[ranked[0][0]] = f"Host: {voice_profiles.metadata.get('host', {}).get('name', 'Unknown')}" console.print(f" [yellow]Assumed {ranked[0][0]} is host " f"(most speaking time: {ranked[0][1]:.0f}s)[/yellow]") # If no voice profiles at all, label by speaking time if not speaker_map: ranked = sorted( [(s, sum(t.duration for t in raw_turns if t.speaker == s)) for s in raw_speakers], key=lambda x: x[1], reverse=True, ) for i, (speaker, time) in enumerate(ranked): if i == 0: speaker_map[speaker] = "Host (assumed)" else: speaker_map[speaker] = f"Speaker {i}" # Apply friendly names for turn in raw_turns: if turn.speaker in speaker_map: turn.speaker = speaker_map[turn.speaker] console.print(f"[green]Diarization complete: {len(raw_turns)} turns, " f"{len(raw_speakers)} speakers[/green]") return DiarizationResult( turns=raw_turns, num_speakers=len(raw_speakers), speaker_map=speaker_map, )