"""Voice profiler: builds and manages speaker embeddings using speechbrain. Uses ECAPA-TDNN speaker verification model to generate embeddings. No HuggingFace gated model access required (unlike pyannote). """ import json import subprocess from dataclasses import dataclass from pathlib import Path import numpy as np import torch import soundfile as sf from rich.console import Console from rich.table import Table console = Console() # Target sample rate for the embedding model SAMPLE_RATE = 16000 # Minimum segment length for a usable embedding (seconds) MIN_SEGMENT_S = 3.0 # Maximum segment length to process at once (seconds) MAX_SEGMENT_S = 30.0 @dataclass class VoiceSegment: """A segment of audio attributed to a single speaker.""" start: float end: float embedding: np.ndarray | None = None speaker_label: str = "" @property def duration(self) -> float: return self.end - self.start @dataclass class SpeakerProfile: """A speaker's voice profile built from multiple embeddings.""" name: str role: str # "host", "cohost", "guest", "caller" embeddings: list[np.ndarray] source_episodes: list[str] composite_embedding: np.ndarray | None = None @property def num_samples(self) -> int: return len(self.embeddings) def compute_composite(self): """Average all embeddings into a single composite.""" if self.embeddings: self.composite_embedding = np.mean(self.embeddings, axis=0) # L2 normalize norm = np.linalg.norm(self.composite_embedding) if norm > 0: self.composite_embedding /= norm def similarity(self, embedding: np.ndarray) -> float: """Cosine similarity between an embedding and this profile's composite.""" if self.composite_embedding is None: self.compute_composite() return float(np.dot(self.composite_embedding, embedding) / ( np.linalg.norm(self.composite_embedding) * np.linalg.norm(embedding) + 1e-8 )) class VoiceProfiler: """Builds speaker voice profiles from audio using speechbrain ECAPA-TDNN.""" def __init__(self, profiles_dir: str | Path, device: str = "cuda"): self.profiles_dir = Path(profiles_dir) self.profiles_dir.mkdir(parents=True, exist_ok=True) self.device = device self._model = None self.profiles: dict[str, SpeakerProfile] = {} self._load_existing_profiles() def _get_model(self): """Lazy-load the embedding model (WavLM x-vector).""" if self._model is None: console.print("[dim]Loading speaker embedding model (WavLM-SV)...[/dim]") from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector self._extractor = Wav2Vec2FeatureExtractor.from_pretrained( "microsoft/wavlm-base-sv" ) self._model = WavLMForXVector.from_pretrained( "microsoft/wavlm-base-sv" ).to(self.device) self._model.eval() console.print("[dim]Speaker embedding model loaded[/dim]") return self._model def _load_existing_profiles(self): """Load saved profiles from disk.""" profile_file = self.profiles_dir / "profiles.json" if not profile_file.exists(): return with open(profile_file) as f: data = json.load(f) for name, pdata in data.items(): embeddings = [] emb_dir = self.profiles_dir / name.lower().replace(" ", "-") for emb_file in sorted(emb_dir.glob("embedding_*.npy")): embeddings.append(np.load(emb_file)) composite = None composite_file = emb_dir / "composite.npy" if composite_file.exists(): composite = np.load(composite_file) self.profiles[name] = SpeakerProfile( name=name, role=pdata.get("role", "unknown"), embeddings=embeddings, source_episodes=pdata.get("source_episodes", []), composite_embedding=composite, ) if self.profiles: console.print(f"[dim]Loaded {len(self.profiles)} voice profiles[/dim]") def save_profiles(self): """Save all profiles to disk.""" metadata = {} for name, profile in self.profiles.items(): slug = name.lower().replace(" ", "-") emb_dir = self.profiles_dir / slug emb_dir.mkdir(parents=True, exist_ok=True) # Save individual embeddings for i, emb in enumerate(profile.embeddings): np.save(emb_dir / f"embedding_{i:04d}.npy", emb) # Save composite profile.compute_composite() if profile.composite_embedding is not None: np.save(emb_dir / "composite.npy", profile.composite_embedding) metadata[name] = { "role": profile.role, "num_samples": profile.num_samples, "source_episodes": profile.source_episodes, } with open(self.profiles_dir / "profiles.json", "w") as f: json.dump(metadata, f, indent=2) console.print(f"[green]Saved {len(self.profiles)} voice profiles[/green]") def extract_embedding(self, audio_path: Path, start: float = 0.0, end: float | None = None) -> np.ndarray: """Extract a speaker embedding from an audio segment (file-based, any format).""" self._get_model() waveform, _ = self._load_audio_segment(audio_path, start, end) return self._embed_audio_np(waveform.squeeze(0).numpy()) def _embed_audio_np(self, audio_np: np.ndarray) -> np.ndarray: """Embed a float32 mono numpy array (already at SAMPLE_RATE). Returns L2-normalized embedding.""" self._get_model() inputs = self._extractor( audio_np, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True, ) with torch.no_grad(): outputs = self._model(**{k: v.to(self.device) for k, v in inputs.items()}) embedding = outputs.embeddings.squeeze().cpu().numpy() norm = np.linalg.norm(embedding) if norm > 0: embedding = embedding / norm return embedding def _load_full_audio(self, audio_path: Path) -> np.ndarray: """Decode entire audio file to float32 mono at SAMPLE_RATE via a single ffmpeg call.""" cmd = [ "ffmpeg", "-i", str(audio_path), "-f", "wav", "-ac", "1", "-ar", str(SAMPLE_RATE), "-acodec", "pcm_s16le", "pipe:1", ] result = subprocess.run(cmd, capture_output=True, timeout=600) if result.returncode != 0: raise RuntimeError(f"ffmpeg failed: {result.stderr.decode()[:200]}") import io data, _ = sf.read(io.BytesIO(result.stdout), dtype="float32") return data # shape: (samples,) def _load_audio_segment(self, audio_path: Path, start: float = 0.0, end: float | None = None) -> tuple[torch.Tensor, int]: """Load a single audio segment via ffmpeg (used for one-off extraction).""" cmd = ["ffmpeg", "-i", str(audio_path)] if start > 0: cmd.extend(["-ss", str(start)]) if end is not None: cmd.extend(["-t", str(end - start)]) cmd.extend(["-f", "wav", "-ac", "1", "-ar", str(SAMPLE_RATE), "-acodec", "pcm_s16le", "pipe:1"]) result = subprocess.run(cmd, capture_output=True, timeout=60) if result.returncode != 0: raise RuntimeError(f"ffmpeg failed: {result.stderr.decode()[:200]}") import io data, sr = sf.read(io.BytesIO(result.stdout), dtype="float32") waveform = torch.from_numpy(data).unsqueeze(0) # [1, samples] return waveform, sr def bootstrap_host_from_episodes(self, episode_paths: list[Path], host_name: str = "Mike Swanson"): """Build host voice profile by extracting the dominant speaker from episodes. Strategy: In each episode, the host speaks the most. We extract embeddings from the first 2-5 minutes (usually the intro/monologue) where the host is most likely speaking solo. """ console.print(f"[bold]Bootstrapping voice profile for {host_name}[/bold]") console.print(f"[dim]Processing {len(episode_paths)} episodes[/dim]") if host_name not in self.profiles: self.profiles[host_name] = SpeakerProfile( name=host_name, role="host", embeddings=[], source_episodes=[], ) profile = self.profiles[host_name] for ep_idx, ep_path in enumerate(episode_paths, 1): console.print(f"[dim] [{ep_idx}/{len(episode_paths)}] {ep_path.name}[/dim]") try: duration = self._get_duration(ep_path) windows = [] if duration > 90: windows.append((30.0, 90.0)) if duration > 180: windows.append((120.0, 180.0)) mid = duration / 2 if mid > 60: windows.append((mid, min(mid + 60, duration))) late = duration - 180 if late > 300: windows.append((late, late + 60)) chunk_duration = 10.0 for start, end in windows: for chunk_start in np.arange(start, end - chunk_duration, chunk_duration): try: emb = self.extract_embedding( ep_path, chunk_start, chunk_start + chunk_duration ) profile.embeddings.append(emb) except Exception as e: console.print(f" [dim red]Chunk {chunk_start:.0f}s failed: {e}[/dim red]") profile.source_episodes.append(ep_path.name) except Exception as e: console.print(f" [red]Failed: {ep_path.name}: {e}[/red]") # Compute composite profile.compute_composite() console.print(f"\n[green]Host profile built: {profile.num_samples} embeddings " f"from {len(profile.source_episodes)} episodes[/green]") # Save self.save_profiles() def identify_speakers(self, audio_path: Path, window_s: float = 10.0, hop_s: float = 5.0, threshold: float = 0.70, skip_ranges: list[tuple[float, float]] | None = None ) -> list[VoiceSegment]: """Identify speakers throughout an audio file using sliding window. Loads the full audio once then slices in memory — avoids spawning hundreds of ffmpeg subprocesses. Returns timestamped segments with speaker labels and embeddings. skip_ranges: list of (start, end) seconds. Windows whose midpoint falls inside any of these ranges are labeled "[bumper]" and the speaker cosine match is skipped — used to suppress music/promo from being matched against speaker profiles. """ console.print(f"[bold]Identifying speakers:[/bold] {audio_path.name}") duration = self._get_duration(audio_path) console.print(f"[dim]Loading audio into memory...[/dim]") audio = self._load_full_audio(audio_path) # float32 mono array self._get_model() # ensure model is warm before the loop skip_ranges = skip_ranges or [] segments = [] window_samples = int(window_s * SAMPLE_RATE) hop_samples = int(hop_s * SAMPLE_RATE) total_samples = len(audio) total_windows = int((duration - window_s) / hop_s) + 1 report_every = max(1, total_windows // 10) for idx, start in enumerate(np.arange(0, duration - window_s, hop_s)): end = min(start + window_s, duration) s = int(start * SAMPLE_RATE) e = min(s + window_samples, total_samples) mid = (start + end) / 2 in_bumper = any(rs <= mid <= re for rs, re in skip_ranges) if in_bumper: segments.append(VoiceSegment( start=start, end=end, speaker_label="[bumper] (1.00)", )) continue try: emb = self._embed_audio_np(audio[s:e]) best_match = None best_score = 0.0 for name, profile in self.profiles.items(): score = profile.similarity(emb) if score > best_score: best_score = score best_match = name if best_score >= threshold: role = self.profiles[best_match].role if best_match else "unknown" if role == "host": label = f"Host: {best_match}" elif role == "cohost": label = f"Cohost: {best_match}" else: label = best_match else: label = "Unknown" segments.append(VoiceSegment( start=start, end=end, embedding=emb, speaker_label=f"{label} ({best_score:.2f})", )) except Exception: segments.append(VoiceSegment( start=start, end=end, speaker_label="[error]", )) if idx % report_every == 0: pct = int(end / duration * 100) console.print(f"[dim] {pct}% ({end:.0f}s / {duration:.0f}s)[/dim]") # Print summary self._print_speaker_summary(segments, duration) return segments def _print_speaker_summary(self, segments: list[VoiceSegment], duration: float): """Print a summary of who spoke and for how long.""" speaker_times: dict[str, float] = {} for seg in segments: label = seg.speaker_label.split(" (")[0] # Strip score speaker_times[label] = speaker_times.get(label, 0) + seg.duration table = Table(title="Speaker Summary") table.add_column("Speaker", style="cyan") table.add_column("Time", style="magenta") table.add_column("Percentage", style="green") for speaker, time in sorted(speaker_times.items(), key=lambda x: -x[1]): pct = (time / duration) * 100 table.add_row(speaker, f"{time:.0f}s", f"{pct:.1f}%") console.print(table) def _get_duration(self, audio_path: Path) -> float: """Get audio duration in seconds.""" result = subprocess.run( ["ffprobe", "-v", "quiet", "-show_entries", "format=duration", "-of", "csv=p=0", str(audio_path)], capture_output=True, text=True, ) return float(result.stdout.strip()) def print_profiles(self): """Print summary of all loaded profiles.""" table = Table(title="Voice Profiles") table.add_column("Name", style="cyan") table.add_column("Role", style="green") table.add_column("Samples", style="magenta") table.add_column("Episodes", style="yellow") for name, profile in self.profiles.items(): table.add_row( name, profile.role, str(profile.num_samples), str(len(profile.source_episodes)), ) console.print(table)