"""Voice profiler: builds and manages speaker embeddings.

Uses the WavLM x-vector speaker verification model
(``microsoft/wavlm-base-sv``, loaded through ``transformers``) to generate
embeddings. No HuggingFace gated model access required (unlike pyannote).
"""

import io
import json
import subprocess
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch
import soundfile as sf
from rich.console import Console
from rich.markup import escape
from rich.progress import (
    Progress,
    SpinnerColumn,
    TextColumn,
    BarColumn,
    TimeElapsedColumn,
)
from rich.table import Table

console = Console()

# Target sample rate for the embedding model
SAMPLE_RATE = 16000

# Minimum segment length for a usable embedding (seconds)
MIN_SEGMENT_S = 3.0

# Maximum segment length to process at once (seconds)
MAX_SEGMENT_S = 30.0


def _profile_slug(name: str) -> str:
    """Return the sub-directory name used to store a profile's embeddings."""
    return name.lower().replace(" ", "-")


@dataclass
class VoiceSegment:
    """A segment of audio attributed to a single speaker."""

    start: float
    end: float
    embedding: np.ndarray | None = None
    speaker_label: str = ""

    @property
    def duration(self) -> float:
        """Length of the segment in seconds (``end - start``)."""
        return self.end - self.start


@dataclass
class SpeakerProfile:
    """A speaker's voice profile built from multiple embeddings."""

    name: str
    role: str  # "host", "cohost", "guest", "caller"
    embeddings: list[np.ndarray]
    source_episodes: list[str]
    composite_embedding: np.ndarray | None = None

    @property
    def num_samples(self) -> int:
        """Number of individual embeddings held by this profile."""
        return len(self.embeddings)

    def compute_composite(self):
        """Average all embeddings into a single L2-normalized composite.

        Does nothing when the profile has no embeddings, so a composite
        that was loaded from disk is left untouched in that case.
        """
        if not self.embeddings:
            return
        composite = np.mean(self.embeddings, axis=0)
        norm = np.linalg.norm(composite)
        if norm > 0:
            composite = composite / norm
        self.composite_embedding = composite

    def similarity(self, embedding: np.ndarray) -> float:
        """Cosine similarity between an embedding and this profile's composite.

        Returns 0.0 when the profile has neither a composite nor any
        embeddings to build one from, instead of failing on ``None``.
        """
        if self.composite_embedding is None:
            self.compute_composite()
        if self.composite_embedding is None:
            return 0.0
        # The small epsilon guards against division by zero for null vectors.
        denominator = (
            np.linalg.norm(self.composite_embedding) * np.linalg.norm(embedding) + 1e-8
        )
        return float(np.dot(self.composite_embedding, embedding) / denominator)


class VoiceProfiler:
    """Builds speaker voice profiles from audio using WavLM x-vector embeddings.

    Profiles are persisted under ``profiles_dir``: a ``profiles.json``
    metadata file plus one sub-directory of ``.npy`` embeddings per speaker.
    """

    def __init__(self, profiles_dir: str | Path, device: str = "cuda"):
        """Create the profiles directory if needed and load saved profiles.

        Args:
            profiles_dir: Directory holding ``profiles.json`` and embeddings.
            device: Torch device the embedding model is moved to on first use.
        """
        self.profiles_dir = Path(profiles_dir)
        self.profiles_dir.mkdir(parents=True, exist_ok=True)
        self.device = device
        self._model = None
        self._extractor = None
        self.profiles: dict[str, SpeakerProfile] = {}
        self._load_existing_profiles()

    def _get_model(self):
        """Lazy-load the embedding model (WavLM x-vector) and its extractor."""
        if self._model is None:
            console.print("[dim]Loading speaker embedding model (WavLM-SV)...[/dim]")
            # Imported lazily so that importing this module stays cheap.
            from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector

            self._extractor = Wav2Vec2FeatureExtractor.from_pretrained(
                "microsoft/wavlm-base-sv"
            )
            self._model = WavLMForXVector.from_pretrained(
                "microsoft/wavlm-base-sv"
            ).to(self.device)
            self._model.eval()
            console.print("[dim]Speaker embedding model loaded[/dim]")
        return self._model

    def _load_existing_profiles(self):
        """Load saved profiles (metadata, embeddings, composite) from disk."""
        profile_file = self.profiles_dir / "profiles.json"
        if not profile_file.exists():
            return

        with open(profile_file, encoding="utf-8") as f:
            data = json.load(f)

        for name, pdata in data.items():
            emb_dir = self.profiles_dir / _profile_slug(name)
            # Sorted so embeddings come back in the order they were saved.
            embeddings = [
                np.load(emb_file)
                for emb_file in sorted(emb_dir.glob("embedding_*.npy"))
            ]

            composite = None
            composite_file = emb_dir / "composite.npy"
            if composite_file.exists():
                composite = np.load(composite_file)

            self.profiles[name] = SpeakerProfile(
                name=name,
                role=pdata.get("role", "unknown"),
                embeddings=embeddings,
                source_episodes=pdata.get("source_episodes", []),
                composite_embedding=composite,
            )

        if self.profiles:
            console.print(f"[dim]Loaded {len(self.profiles)} voice profiles[/dim]")

    def save_profiles(self):
        """Save all profiles to disk.

        Writes every embedding and the (recomputed) composite of each
        profile as ``.npy`` files, then the ``profiles.json`` metadata.
        """
        metadata = {}
        for name, profile in self.profiles.items():
            emb_dir = self.profiles_dir / _profile_slug(name)
            emb_dir.mkdir(parents=True, exist_ok=True)

            # Save individual embeddings
            for i, emb in enumerate(profile.embeddings):
                np.save(emb_dir / f"embedding_{i:04d}.npy", emb)

            # Save composite
            profile.compute_composite()
            if profile.composite_embedding is not None:
                np.save(emb_dir / "composite.npy", profile.composite_embedding)

            metadata[name] = {
                "role": profile.role,
                "num_samples": profile.num_samples,
                "source_episodes": profile.source_episodes,
            }

        with open(self.profiles_dir / "profiles.json", "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2)

        console.print(f"[green]Saved {len(self.profiles)} voice profiles[/green]")

    def extract_embedding(self, audio_path: Path, start: float = 0.0,
                          end: float | None = None) -> np.ndarray:
        """Extract an L2-normalized speaker embedding from an audio segment.

        Args:
            audio_path: Audio file to read.
            start: Segment start in seconds.
            end: Segment end in seconds, or ``None`` for the rest of the file.

        Raises:
            ValueError: If no audio samples could be decoded for the segment.
            RuntimeError: If ffmpeg fails to decode the file.
        """
        model = self._get_model()

        # Load audio segment (already at SAMPLE_RATE via ffmpeg)
        waveform, _sr = self._load_audio_segment(audio_path, start, end)

        # waveform is a [1, samples] tensor; the extractor wants the raw array
        audio_np = waveform.squeeze(0).numpy()
        if audio_np.size == 0:
            raise ValueError(
                f"No audio samples decoded from {audio_path} at {start:.1f}s"
            )

        # Extract features
        inputs = self._extractor(
            audio_np,
            sampling_rate=SAMPLE_RATE,
            return_tensors="pt",
            padding=True,
        )

        # Get embedding
        with torch.no_grad():
            outputs = model(**{k: v.to(self.device) for k, v in inputs.items()})
        embedding = outputs.embeddings.squeeze().cpu().numpy()

        # L2 normalize
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm

        return embedding

    def _load_audio_segment(self, audio_path: Path, start: float = 0.0,
                            end: float | None = None) -> tuple[torch.Tensor, int]:
        """Load an audio segment using ffmpeg (handles any format).

        The audio is decoded to mono 16-bit PCM at ``SAMPLE_RATE``.

        Returns:
            A ``(waveform, sample_rate)`` pair; waveform is ``[1, samples]``.

        Raises:
            ValueError: If ``end`` is not greater than ``start``.
            RuntimeError: If ffmpeg exits with a non-zero status.
        """
        if end is not None and end <= start:
            raise ValueError(f"end ({end}) must be greater than start ({start})")

        # -hide_banner/-loglevel error keep stderr limited to real errors, so
        # the truncated message below is actually informative.
        cmd = ["ffmpeg", "-hide_banner", "-loglevel", "error"]
        if start > 0:
            # As an input option (before -i) ffmpeg seeks instead of decoding
            # and discarding everything before `start`, which matters when
            # many chunks are pulled from deep inside a long episode.
            cmd.extend(["-ss", str(start)])
        cmd.extend(["-i", str(audio_path)])
        if end is not None:
            cmd.extend(["-t", str(end - start)])
        cmd.extend(["-f", "wav", "-ac", "1", "-ar", str(SAMPLE_RATE),
                    "-acodec", "pcm_s16le", "pipe:1"])

        result = subprocess.run(cmd, capture_output=True, timeout=60)
        if result.returncode != 0:
            stderr = result.stderr.decode(errors="replace")
            raise RuntimeError(f"ffmpeg failed: {stderr[:200]}")

        data, sr = sf.read(io.BytesIO(result.stdout), dtype="float32")
        waveform = torch.from_numpy(data).unsqueeze(0)  # [1, samples]
        return waveform, sr

    @staticmethod
    def _host_sample_windows(duration: float) -> list[tuple[float, float]]:
        """Return the (start, end) windows sampled when bootstrapping a host."""
        windows: list[tuple[float, float]] = []

        # Window 1: After intro (30s-90s) — usually host monologue
        if duration > 90:
            windows.append((30.0, 90.0))

        # Window 2: Early show (2min-3min)
        if duration > 180:
            windows.append((120.0, 180.0))

        # Window 3: Mid show
        mid = duration / 2
        if mid > 60:
            windows.append((mid, min(mid + 60, duration)))

        # Window 4: Late show (but not last 2 min — likely outro)
        late = duration - 180
        if late > 300:
            windows.append((late, late + 60))

        return windows

    @staticmethod
    def _window_starts(start: float, end: float, length: float,
                       hop: float) -> list[float]:
        """Return start times of every full ``length`` window in [start, end]."""
        span = end - start
        if span < length:
            return []
        # Counting windows (rather than np.arange with an exclusive stop)
        # keeps the final full window; the epsilon absorbs float rounding.
        count = int((span - length + 1e-9) // hop) + 1
        return [start + i * hop for i in range(count)]

    def bootstrap_host_from_episodes(self, episode_paths: list[Path],
                                     host_name: str = "Mike Swanson"):
        """Build host voice profile by sampling the dominant speaker.

        Strategy: the host is assumed to speak the most, so each episode is
        sampled in a few one-minute windows (after the intro, early, middle
        and late in the show) and every 10-second chunk of those windows
        contributes one embedding. The profile is saved when done.

        Args:
            episode_paths: Audio files to sample.
            host_name: Profile name; created with role "host" if missing.
        """
        console.print(f"[bold]Bootstrapping voice profile for {host_name}[/bold]")
        console.print(f"[dim]Processing {len(episode_paths)} episodes[/dim]")

        if host_name not in self.profiles:
            self.profiles[host_name] = SpeakerProfile(
                name=host_name,
                role="host",
                embeddings=[],
                source_episodes=[],
            )

        profile = self.profiles[host_name]
        chunk_duration = 10.0

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("{task.completed}/{task.total}"),
            TimeElapsedColumn(),
            console=console,
        ) as progress:
            task = progress.add_task("Processing episodes...", total=len(episode_paths))

            for ep_path in episode_paths:
                # File names and error text are escaped so stray brackets are
                # not interpreted as rich markup.
                ep_name = escape(ep_path.name)
                progress.update(task, description=f"Processing {ep_name}...")

                try:
                    duration = self._get_duration(ep_path)
                    added = 0

                    for start, end in self._host_sample_windows(duration):
                        for chunk_start in self._window_starts(
                            start, end, chunk_duration, chunk_duration
                        ):
                            try:
                                emb = self.extract_embedding(
                                    ep_path, chunk_start, chunk_start + chunk_duration
                                )
                            except Exception as e:
                                console.print(
                                    f"  [dim red]Chunk {chunk_start:.0f}s failed: "
                                    f"{escape(str(e))}[/dim red]"
                                )
                                continue
                            profile.embeddings.append(emb)
                            added += 1

                    # Only list episodes that actually contributed, and only once.
                    if added and ep_path.name not in profile.source_episodes:
                        profile.source_episodes.append(ep_path.name)

                except Exception as e:
                    console.print(f"  [red]Failed: {ep_name}: {escape(str(e))}[/red]")

                progress.update(task, advance=1)

        # Compute composite
        profile.compute_composite()

        console.print(f"\n[green]Host profile built: {profile.num_samples} embeddings "
                      f"from {len(profile.source_episodes)} episodes[/green]")

        # Save
        self.save_profiles()

    def identify_speakers(self, audio_path: Path,
                          window_s: float = 10.0,
                          hop_s: float = 5.0,
                          threshold: float = 0.70) -> list[VoiceSegment]:
        """Identify speakers throughout an audio file using a sliding window.

        Each window is matched against every known profile; the best match
        is used as the label when its score reaches ``threshold``, otherwise
        the window is labelled "Unknown". Windows whose embedding cannot be
        extracted are labelled "[error]".

        Args:
            audio_path: Audio file to analyze.
            window_s: Window length in seconds.
            hop_s: Step between consecutive window starts in seconds.
            threshold: Minimum cosine similarity to accept a match.

        Returns:
            Timestamped segments with speaker labels and embeddings.

        Raises:
            ValueError: If ``window_s`` or ``hop_s`` is not positive, or the
                file duration cannot be determined.
        """
        if window_s <= 0 or hop_s <= 0:
            raise ValueError("window_s and hop_s must be positive")

        console.print(f"[bold]Identifying speakers:[/bold] {escape(audio_path.name)}")

        duration = self._get_duration(audio_path)
        segments = []

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TextColumn("{task.percentage:>3.0f}%"),
            TimeElapsedColumn(),
            console=console,
        ) as progress:
            task = progress.add_task("Analyzing speakers...", total=int(duration))

            for start in self._window_starts(0.0, duration, window_s, hop_s):
                end = min(start + window_s, duration)

                try:
                    emb = self.extract_embedding(audio_path, start, end)
                except Exception as e:
                    console.print(
                        f"  [dim red]Window {start:.0f}s failed: "
                        f"{escape(str(e))}[/dim red]"
                    )
                    segments.append(VoiceSegment(
                        start=start,
                        end=end,
                        speaker_label="[error]",
                    ))
                else:
                    # Match against known profiles
                    best_match = None
                    best_score = 0.0
                    for name, profile in self.profiles.items():
                        score = profile.similarity(emb)
                        if score > best_score:
                            best_score = score
                            best_match = name

                    # The host prefix is only applied to accepted matches, so
                    # a below-threshold window stays "Unknown".
                    if best_match is not None and best_score >= threshold:
                        if self.profiles[best_match].role == "host":
                            label = f"Host: {best_match}"
                        else:
                            label = best_match
                    else:
                        label = "Unknown"

                    segments.append(VoiceSegment(
                        start=start,
                        end=end,
                        embedding=emb,
                        speaker_label=f"{label} ({best_score:.2f})",
                    ))

                progress.update(task, completed=int(end))

        # Print summary
        self._print_speaker_summary(segments, duration)

        return segments

    def _print_speaker_summary(self, segments: list[VoiceSegment], duration: float):
        """Print a summary of who spoke and for how long.

        Segments are expected in chronological order. Where consecutive
        segments overlap (sliding windows), each one is only credited up to
        the start of the next, so time is not counted twice.
        """
        speaker_times: dict[str, float] = {}
        for i, seg in enumerate(segments):
            # Strip the trailing " (score)" only, keeping parentheses in names.
            label = seg.speaker_label.rsplit(" (", 1)[0]

            credited_end = seg.end
            if i + 1 < len(segments):
                next_start = max(segments[i + 1].start, seg.start)
                credited_end = min(credited_end, next_start)

            speaker_times[label] = speaker_times.get(label, 0) + (credited_end - seg.start)

        table = Table(title="Speaker Summary")
        table.add_column("Speaker", style="cyan")
        table.add_column("Time", style="magenta")
        table.add_column("Percentage", style="green")

        for speaker, seconds in sorted(speaker_times.items(), key=lambda x: -x[1]):
            pct = (seconds / duration) * 100 if duration > 0 else 0.0
            # Escaped so labels such as "[error]" are shown, not parsed as markup.
            table.add_row(escape(speaker), f"{seconds:.0f}s", f"{pct:.1f}%")

        console.print(table)

    def _get_duration(self, audio_path: Path) -> float:
        """Get audio duration in seconds using ffprobe.

        Raises:
            ValueError: If ffprobe does not report a numeric duration.
        """
        result = subprocess.run(
            ["ffprobe", "-v", "quiet", "-show_entries", "format=duration",
             "-of", "csv=p=0", str(audio_path)],
            capture_output=True, text=True,
        )
        raw = result.stdout.strip()
        try:
            return float(raw)
        except ValueError:
            raise ValueError(
                f"ffprobe could not determine duration of {audio_path} "
                f"(exit code {result.returncode}, output {raw!r})"
            ) from None

    def print_profiles(self):
        """Print summary of all loaded profiles."""
        table = Table(title="Voice Profiles")
        table.add_column("Name", style="cyan")
        table.add_column("Role", style="green")
        table.add_column("Samples", style="magenta")
        table.add_column("Episodes", style="yellow")

        for name, profile in self.profiles.items():
            table.add_row(
                escape(name),
                escape(profile.role),
                str(profile.num_samples),
                str(len(profile.source_episodes)),
            )

        console.print(table)