Adds a transcript-driven bumper filter to the diarization pipeline. When
a transcript segment matches qa_extractor's promo/bumper signatures, the
audio windows whose midpoint falls inside that segment are labeled BUMPER
and the WavLM cosine match is skipped. This prevents music/promos from
being matched against speaker profiles (the failure mode Mike caught in
2018-s10e18 @ 09:20-10:05).
Code changes:
- src/voice_profiler.py: identify_speakers() takes an optional skip_ranges
parameter; windows whose midpoint falls in a skip range are labeled
"[bumper]" and skip the cosine match
- src/diarizer.py: diarize() takes an optional transcript_path; it
pre-computes bumper time ranges via qa_extractor._is_promo_or_bumper,
passes them to identify_speakers, and adds a BUMPER speaker label
- benchmark.py: passes transcript_path to diarize()
Aggregate impact across the 9-episode test set:
Tara attribution: 4880s -> 3680s (-1200s / -25%)
Q&A pairs: 17 -> 19 (+2)
(bumper-flagged segments had been disrupting conversation detection
in 2017-s9e30 and 2018-s10e18)
CALLER total: 1320s -> 1190s (bumper windows previously labeled CALLER
are now labeled BUMPER)
Bumpers caught per episode: 1-8; ~165 bumper-flagged transcript segments
in total across the set
Remaining Tara false positives are real callers who sound acoustically
similar to Tara (Christopher in 2018, Kay in 2012, William and Charles in
2015) and guest Clay in 2015-s7e19 — those need a Tara profile rebuild
and a new Clay profile, not bumper filtering.
Also adds download_full_archive.py — a resumable, mirror-style downloader
that walks the IX server's /home/gurushow/public_html/archive/{year}/ and
copies all MP3s to archive-data/episodes/. A run is in progress (~589
files, ~10-15 GB). It will be used to source clean profile windows for
the remaining co-hosts (Tara rebuild, Clay, Tony, Rob, Randall, producers).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
271 lines
9.5 KiB
Python
271 lines
9.5 KiB
Python
"""Stage 2: Speaker diarization via WavLM (VoiceProfiler) sliding-window voice profile matching."""
|
|
|
|
import json
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
from rich.console import Console
|
|
|
|
# Module-wide rich console; every progress, warning and "saved" message in
# this module is printed through it.
console = Console()
|
|
|
|
|
|
@dataclass
class SpeakerTurn:
    """A single contiguous stretch of audio attributed to one speaker label.

    ``start`` and ``end`` are offsets in seconds from the beginning of the
    audio; ``confidence`` defaults to 1.0 when the producer supplies none.
    """

    # Label such as "HOST", "CALLER", "BUMPER", or a profile-derived name.
    speaker: str
    start: float
    end: float
    confidence: float = 1.0

    @property
    def duration(self) -> float:
        """Seconds covered by this turn (``end`` minus ``start``)."""
        span = self.end - self.start
        return span
|
|
|
|
|
|
@dataclass
class DiarizationResult:
    """Output of a diarization run: ordered speaker turns plus label bookkeeping."""

    turns: list[SpeakerTurn]
    num_speakers: int
    speaker_map: dict[str, str]  # raw label -> friendly name

    def speaker_at(self, time: float) -> str | None:
        """Return the speaker of the first turn covering ``time``.

        Both turn endpoints are inclusive, so at a boundary shared by two
        turns the earlier one in ``turns`` wins. Returns None if no turn
        covers ``time``.
        """
        hit = next(
            (turn for turn in self.turns if turn.start <= time <= turn.end),
            None,
        )
        return None if hit is None else hit.speaker

    def speaker_time(self, speaker: str) -> float:
        """Summed duration of every turn labeled ``speaker`` (0 when absent)."""
        total = 0
        for turn in self.turns:
            if turn.speaker == speaker:
                total += turn.duration
        return total

    def speakers_ranked(self) -> list[tuple[str, float]]:
        """(speaker, total seconds) pairs ordered from most to least speaking time."""
        ranking: dict[str, float] = {}
        for turn in self.turns:
            ranking[turn.speaker] = ranking.get(turn.speaker, 0) + turn.duration
        return sorted(ranking.items(), key=lambda pair: pair[1], reverse=True)

    def to_dict(self) -> dict:
        """JSON-serializable form, with each turn flattened into a plain dict."""
        serialized_turns = [
            {
                "speaker": turn.speaker,
                "start": turn.start,
                "end": turn.end,
                "confidence": turn.confidence,
            }
            for turn in self.turns
        ]
        return {
            "num_speakers": self.num_speakers,
            "speaker_map": self.speaker_map,
            "turns": serialized_turns,
        }

    def save(self, output_dir: Path):
        """Write ``diarization.json`` (indent=2) into ``output_dir``, creating it if needed."""
        output_dir.mkdir(parents=True, exist_ok=True)
        target = output_dir / "diarization.json"
        with open(target, "w") as fh:
            json.dump(self.to_dict(), fh, indent=2)
        console.print(f"[green]Diarization saved to {output_dir}[/green]")
|
|
|
|
|
|
class VoiceProfileStore:
|
|
"""Manages speaker voice embeddings for identification."""
|
|
|
|
def __init__(self, profiles_dir: str | Path):
|
|
self.profiles_dir = Path(profiles_dir)
|
|
self.embeddings: dict[str, np.ndarray] = {}
|
|
self.metadata: dict[str, dict] = {}
|
|
self._load_profiles()
|
|
|
|
def _load_profiles(self):
|
|
if not self.profiles_dir.exists():
|
|
return
|
|
|
|
for npy_file in self.profiles_dir.rglob("*.npy"):
|
|
name = npy_file.stem
|
|
# Determine speaker name from directory structure
|
|
parent = npy_file.parent.name
|
|
if parent.startswith("host-"):
|
|
speaker_name = parent.replace("host-", "").replace("-", " ").title()
|
|
role = "host"
|
|
elif parent == "guests":
|
|
speaker_name = name.replace("-", " ").title()
|
|
role = "guest"
|
|
elif parent == "callers":
|
|
speaker_name = name
|
|
role = "caller"
|
|
else:
|
|
speaker_name = name
|
|
role = "unknown"
|
|
|
|
self.embeddings[name] = np.load(npy_file)
|
|
self.metadata[name] = {
|
|
"name": speaker_name,
|
|
"role": role,
|
|
"file": str(npy_file),
|
|
}
|
|
|
|
if self.embeddings:
|
|
console.print(f"[dim]Loaded {len(self.embeddings)} voice profiles[/dim]")
|
|
|
|
def match_embedding(self, embedding: np.ndarray, threshold: float = 0.75
|
|
) -> tuple[str | None, float]:
|
|
"""Match an embedding against stored profiles. Returns (name, similarity)."""
|
|
if not self.embeddings:
|
|
return None, 0.0
|
|
|
|
best_match = None
|
|
best_score = 0.0
|
|
|
|
for name, stored in self.embeddings.items():
|
|
# Cosine similarity
|
|
similarity = np.dot(embedding, stored) / (
|
|
np.linalg.norm(embedding) * np.linalg.norm(stored) + 1e-8
|
|
)
|
|
if similarity > best_score:
|
|
best_score = similarity
|
|
best_match = name
|
|
|
|
if best_score >= threshold:
|
|
meta = self.metadata.get(best_match, {})
|
|
friendly_name = meta.get("name", best_match)
|
|
role = meta.get("role", "unknown")
|
|
if role == "host":
|
|
return f"Host: {friendly_name}", best_score
|
|
return friendly_name, best_score
|
|
|
|
return None, best_score
|
|
|
|
def save_embedding(self, name: str, embedding: np.ndarray,
|
|
role: str = "unknown"):
|
|
"""Save a new voice profile."""
|
|
if role == "host":
|
|
subdir = self.profiles_dir / f"host-{name.lower().replace(' ', '-')}"
|
|
elif role == "guest":
|
|
subdir = self.profiles_dir / "guests"
|
|
elif role == "caller":
|
|
subdir = self.profiles_dir / "callers"
|
|
else:
|
|
subdir = self.profiles_dir / "unknown"
|
|
|
|
subdir.mkdir(parents=True, exist_ok=True)
|
|
filename = name.lower().replace(" ", "-")
|
|
np.save(subdir / f"{filename}.npy", embedding)
|
|
console.print(f"[green]Saved voice profile: {name} ({role})[/green]")
|
|
|
|
|
|
def _load_bumper_ranges(transcript_path: str | Path | None) -> list[tuple[float, float]]:
    """Return (start, end) of transcript segments flagged as show promo/bumper text."""
    if transcript_path is None:
        return []
    transcript_path = Path(transcript_path)
    if not transcript_path.exists():
        # Warn rather than skip silently: a bad path would otherwise quietly
        # disable the bumper filter and skew speaker attribution.
        console.print(f"[yellow]Transcript not found: {transcript_path} — "
                      f"bumper filter disabled[/yellow]")
        return []

    from .qa_extractor import _is_promo_or_bumper

    # JSON interchange is UTF-8; don't depend on the platform default encoding.
    with open(transcript_path, encoding="utf-8") as f:
        tdata = json.load(f)

    ranges = [
        (seg["start"], seg["end"])
        for seg in tdata.get("segments", [])
        if _is_promo_or_bumper(seg.get("text", ""))
    ]
    if ranges:
        console.print(
            f"[dim]Bumper filter: {len(ranges)} promo/bumper "
            f"transcript segments will be skipped during speaker match[/dim]"
        )
    return ranges


def _label_to_speaker(raw_label: str) -> str:
    """Map a VoiceProfiler segment label to HOST / CO-HOST / BUMPER / UNKNOWN / CALLER."""
    label = raw_label.split(" (")[0]  # drop a trailing " (score)" suffix
    if label.startswith(("Host:", "Host ")):
        return "HOST"
    if label.startswith("Cohost:"):
        return "CO-HOST"
    if label == "[bumper]":
        return "BUMPER"
    if label == "[error]":
        return "UNKNOWN"
    return "CALLER"


def _label_confidence(raw_label: str, default: float = 0.5) -> float:
    """Parse the number after the last "(" in a label; ``default`` if absent or non-numeric."""
    if "(" not in raw_label:
        return default
    try:
        return float(raw_label.rsplit("(", 1)[-1].rstrip(")"))
    except ValueError:
        # A parenthesised non-numeric suffix must not abort the whole episode.
        return default


def _merge_consecutive_turns(turns: list[SpeakerTurn]) -> list[SpeakerTurn]:
    """Collapse runs of adjacent same-speaker turns, keeping the first turn's confidence."""
    merged: list[SpeakerTurn] = []
    for turn in turns:
        if merged and merged[-1].speaker == turn.speaker:
            # max() so overlapping windows can never shrink an existing turn.
            merged[-1].end = max(merged[-1].end, turn.end)
        else:
            # Copy so the merged list never aliases (and later mutates) inputs.
            merged.append(SpeakerTurn(
                speaker=turn.speaker,
                start=turn.start,
                end=turn.end,
                confidence=turn.confidence,
            ))
    return merged


def diarize(audio_path: str | Path,
            voice_profiles: VoiceProfileStore | None = None,
            min_speakers: int = 1,
            max_speakers: int = 6,
            host_match_threshold: float = 0.85,
            transcript_path: str | Path | None = None) -> DiarizationResult:
    """Run speaker diarization using WavLM sliding-window speaker identification.

    Uses the built-in VoiceProfiler (WavLM x-vectors) — no HuggingFace token
    or gated model required. The audio is scanned in 10 s windows with a 5 s
    hop; each window is matched against the stored voice profiles. Windows are
    labeled HOST, CO-HOST, CALLER, BUMPER or UNKNOWN, and consecutive
    same-label windows are merged into one turn.

    If transcript_path is provided, time ranges containing show promo/bumper
    text are pre-marked and skipped at speaker-identification time so vocal
    music doesn't match cohost profiles.

    Args:
        audio_path: Episode audio file.
        voice_profiles: Store whose ``profiles_dir`` locates the profiles; when
            None, ``voice-profiles`` relative to the working directory is used.
        min_speakers: Not used by the WavLM path; kept for caller compatibility.
        max_speakers: Not used by the WavLM path; kept for caller compatibility.
        host_match_threshold: Similarity threshold forwarded to
            ``VoiceProfiler.identify_speakers``.
        transcript_path: Optional transcript JSON whose ``segments`` entries
            carry ``start``/``end``/``text``; used only for bumper filtering.
            A missing file prints a warning and disables the filter.

    Returns:
        DiarizationResult with merged turns. If no voice profiles are found,
        a single HOST turn spanning the whole file.
    """
    import torch
    from .voice_profiler import VoiceProfiler

    audio_path = Path(audio_path)
    console.print(f"[bold]Diarizing:[/bold] {audio_path.name}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    console.print(f"[dim]Device: {device}[/dim]")

    # Locate voice profiles directory from the VoiceProfileStore path
    profiles_dir = voice_profiles.profiles_dir if voice_profiles else Path("voice-profiles")
    profiler = VoiceProfiler(profiles_dir, device=device)

    if not profiler.profiles:
        console.print("[yellow]No voice profiles found — labeling all as HOST[/yellow]")
        # Return a single HOST turn covering the whole episode
        duration = profiler._get_duration(audio_path)
        return DiarizationResult(
            turns=[SpeakerTurn(speaker="HOST", start=0.0, end=duration)],
            num_speakers=1,
            speaker_map={"HOST": "HOST"},
        )

    bumper_ranges = _load_bumper_ranges(transcript_path)

    # Sliding-window identification: 10s windows, 5s hop
    voice_segs = profiler.identify_speakers(
        audio_path, window_s=10.0, hop_s=5.0,
        threshold=host_match_threshold,
        skip_ranges=bumper_ranges,
    )

    raw_turns = [
        SpeakerTurn(
            speaker=_label_to_speaker(seg.speaker_label),
            start=seg.start,
            end=seg.end,
            confidence=_label_confidence(seg.speaker_label),
        )
        for seg in voice_segs
    ]
    merged = _merge_consecutive_turns(raw_turns)

    unique_speakers = {t.speaker for t in merged}
    speaker_map = {s: s for s in unique_speakers}

    host_time = sum(t.duration for t in merged if t.speaker == "HOST")
    caller_time = sum(t.duration for t in merged if t.speaker == "CALLER")
    console.print(f"[green]Diarization complete:[/green] {len(merged)} turns | "
                  f"HOST {host_time:.0f}s / CALLER {caller_time:.0f}s")

    return DiarizationResult(
        turns=merged,
        num_speakers=len(unique_speakers),
        speaker_map=speaker_map,
    )
|