- Audio processor CLI tool with 6-stage pipeline: transcribe (faster-whisper GPU), diarize (pyannote), detect segments (multi-signal classifier), remove commercials, split segments, analyze content (Ollama) - Post-show workflow doc for episode posts, forum threads, deep-dive blog posts - Training plan for using 579-episode archive for voice profiles and commercial detection - Successful test: 45min episode transcribed in 2:37 on RTX 5070 Ti - Sample transcript output from S7E30 (March 2015) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
275 lines
9.7 KiB
Python
275 lines
9.7 KiB
Python
"""Stage 2: Speaker diarization using pyannote.audio with voice profile matching."""
|
|
|
|
import json
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
from rich.console import Console
|
|
|
|
# Module-wide Rich console; every status, warning and progress message in this
# module is printed through it.
console: Console = Console()
|
|
|
|
|
|
@dataclass
class SpeakerTurn:
    """A single span of audio, [start, end], attributed to one speaker label."""

    speaker: str  # "SPEAKER_00", "Host: Mike Swanson", "Caller 1", etc.
    start: float
    end: float
    confidence: float = 1.0

    @property
    def duration(self) -> float:
        """Length of the turn, i.e. ``end`` minus ``start``."""
        elapsed = self.end - self.start
        return elapsed
|
|
|
|
|
|
@dataclass
class DiarizationResult:
    """Speaker turns for one recording, plus how raw labels were renamed.

    ``speaker_map`` records the mapping from raw diarization labels to the
    friendly names that the turns now carry.
    """

    turns: list[SpeakerTurn]
    num_speakers: int
    speaker_map: dict[str, str]  # raw label -> friendly name

    def speaker_at(self, time: float) -> str | None:
        """Return the speaker whose turn covers ``time``, or None.

        Turn boundaries are inclusive. When turns overlap, the first matching
        turn in ``turns`` order wins.
        """
        return next(
            (turn.speaker for turn in self.turns if turn.start <= time <= turn.end),
            None,
        )

    def speaker_time(self, speaker: str) -> float:
        """Sum of the durations of every turn labelled ``speaker``."""
        durations = (turn.duration for turn in self.turns if turn.speaker == speaker)
        return sum(durations)

    def speakers_ranked(self) -> list[tuple[str, float]]:
        """List (speaker, total duration) pairs, longest talker first.

        Equal totals keep the order in which speakers first appear.
        """
        totals: dict[str, float] = {}
        for turn in self.turns:
            previous = totals.get(turn.speaker, 0)
            totals[turn.speaker] = previous + turn.duration
        ranking = list(totals.items())
        ranking.sort(key=lambda pair: pair[1], reverse=True)
        return ranking

    def to_dict(self) -> dict:
        """Plain-dict form of the result, as written by ``save``."""
        serialized_turns = []
        for turn in self.turns:
            serialized_turns.append({
                "speaker": turn.speaker,
                "start": turn.start,
                "end": turn.end,
                "confidence": turn.confidence,
            })
        return {
            "num_speakers": self.num_speakers,
            "speaker_map": self.speaker_map,
            "turns": serialized_turns,
        }

    def save(self, output_dir: Path):
        """Write ``diarization.json`` into ``output_dir``, creating it if needed."""
        output_dir.mkdir(parents=True, exist_ok=True)
        target = output_dir / "diarization.json"
        with target.open("w") as handle:
            json.dump(self.to_dict(), handle, indent=2)
        console.print(f"[green]Diarization saved to {output_dir}[/green]")
|
|
|
|
|
|
class VoiceProfileStore:
|
|
"""Manages speaker voice embeddings for identification."""
|
|
|
|
def __init__(self, profiles_dir: str | Path):
|
|
self.profiles_dir = Path(profiles_dir)
|
|
self.embeddings: dict[str, np.ndarray] = {}
|
|
self.metadata: dict[str, dict] = {}
|
|
self._load_profiles()
|
|
|
|
def _load_profiles(self):
|
|
if not self.profiles_dir.exists():
|
|
return
|
|
|
|
for npy_file in self.profiles_dir.rglob("*.npy"):
|
|
name = npy_file.stem
|
|
# Determine speaker name from directory structure
|
|
parent = npy_file.parent.name
|
|
if parent.startswith("host-"):
|
|
speaker_name = parent.replace("host-", "").replace("-", " ").title()
|
|
role = "host"
|
|
elif parent == "guests":
|
|
speaker_name = name.replace("-", " ").title()
|
|
role = "guest"
|
|
elif parent == "callers":
|
|
speaker_name = name
|
|
role = "caller"
|
|
else:
|
|
speaker_name = name
|
|
role = "unknown"
|
|
|
|
self.embeddings[name] = np.load(npy_file)
|
|
self.metadata[name] = {
|
|
"name": speaker_name,
|
|
"role": role,
|
|
"file": str(npy_file),
|
|
}
|
|
|
|
if self.embeddings:
|
|
console.print(f"[dim]Loaded {len(self.embeddings)} voice profiles[/dim]")
|
|
|
|
def match_embedding(self, embedding: np.ndarray, threshold: float = 0.75
|
|
) -> tuple[str | None, float]:
|
|
"""Match an embedding against stored profiles. Returns (name, similarity)."""
|
|
if not self.embeddings:
|
|
return None, 0.0
|
|
|
|
best_match = None
|
|
best_score = 0.0
|
|
|
|
for name, stored in self.embeddings.items():
|
|
# Cosine similarity
|
|
similarity = np.dot(embedding, stored) / (
|
|
np.linalg.norm(embedding) * np.linalg.norm(stored) + 1e-8
|
|
)
|
|
if similarity > best_score:
|
|
best_score = similarity
|
|
best_match = name
|
|
|
|
if best_score >= threshold:
|
|
meta = self.metadata.get(best_match, {})
|
|
friendly_name = meta.get("name", best_match)
|
|
role = meta.get("role", "unknown")
|
|
if role == "host":
|
|
return f"Host: {friendly_name}", best_score
|
|
return friendly_name, best_score
|
|
|
|
return None, best_score
|
|
|
|
def save_embedding(self, name: str, embedding: np.ndarray,
|
|
role: str = "unknown"):
|
|
"""Save a new voice profile."""
|
|
if role == "host":
|
|
subdir = self.profiles_dir / f"host-{name.lower().replace(' ', '-')}"
|
|
elif role == "guest":
|
|
subdir = self.profiles_dir / "guests"
|
|
elif role == "caller":
|
|
subdir = self.profiles_dir / "callers"
|
|
else:
|
|
subdir = self.profiles_dir / "unknown"
|
|
|
|
subdir.mkdir(parents=True, exist_ok=True)
|
|
filename = name.lower().replace(" ", "-")
|
|
np.save(subdir / f"{filename}.npy", embedding)
|
|
console.print(f"[green]Saved voice profile: {name} ({role})[/green]")
|
|
|
|
|
|
def _rank_by_speaking_time(turns: list[SpeakerTurn]) -> list[tuple[str, float]]:
    """Return (speaker, total duration) pairs, most talkative first.

    Totals are built in a single pass over ``turns``. Equal totals keep
    first-appearance order, so rankings (and labels derived from them) are
    deterministic.
    """
    totals: dict[str, float] = {}
    for turn in turns:
        totals[turn.speaker] = totals.get(turn.speaker, 0.0) + turn.duration
    return sorted(totals.items(), key=lambda item: item[1], reverse=True)


def _host_profile_name(voice_profiles: VoiceProfileStore) -> str:
    """Display name of the first profile with role "host", else "Unknown"."""
    for meta in voice_profiles.metadata.values():
        if meta.get("role") == "host":
            return meta.get("name", "Unknown")
    # Legacy fallback: a profile literally keyed "host".
    return voice_profiles.metadata.get("host", {}).get("name", "Unknown")


def _match_voice_profiles(pipeline, audio_path: Path,
                          longest_segments: dict,
                          speaking_time: dict[str, float],
                          voice_profiles: VoiceProfileStore,
                          threshold: float,
                          device) -> dict[str, str]:
    """Map raw speaker labels to profile names via each speaker's longest segment.

    Best-effort: if the embedding model cannot be loaded, or one speaker
    cannot be embedded or matched, a warning is printed and that speaker (or
    all of them) is left unmapped.
    """
    from pyannote.audio import Inference

    try:
        # Same embedding model reference the original code used; a "whole"
        # window yields a single vector for each cropped excerpt.
        inference = Inference(pipeline.embedding, window="whole", device=device)
    except Exception as e:
        console.print(f" [yellow]Could not load speaker embedding model: {e}[/yellow]")
        return {}

    matches: dict[str, str] = {}
    for raw_label, segment in longest_segments.items():
        try:
            embedding = inference.crop(str(audio_path), segment)
            # crop may return (D,) or (1, D); match on a flat vector.
            match_name, score = voice_profiles.match_embedding(
                np.asarray(embedding).ravel(),
                threshold=threshold,
            )
        except Exception as e:
            console.print(f" [yellow]Could not match {raw_label}: {e}[/yellow]")
            continue

        if match_name:
            matches[raw_label] = match_name
            console.print(f" [green]{raw_label} -> {match_name} "
                          f"(score: {score:.2f}, {speaking_time[raw_label]:.0f}s)[/green]")
    return matches


def diarize(audio_path: str | Path,
            voice_profiles: VoiceProfileStore | None = None,
            min_speakers: int = 1,
            max_speakers: int = 6,
            host_match_threshold: float = 0.75) -> DiarizationResult:
    """Run speaker diarization on an audio file.

    The file goes through pyannote's ``speaker-diarization-3.1`` pipeline (on
    CUDA when available). Each detected speaker then gets a friendly name:

    1. If ``voice_profiles`` has embeddings, each speaker's longest segment is
       embedded and matched against the stored profiles.
    2. If profiles exist but none matched, the speaker with the most speaking
       time is labelled as the store's host profile.
    3. Otherwise the top speaker becomes "Host (assumed)" and the rest become
       "Speaker 1", "Speaker 2", ... in order of descending speaking time.

    Args:
        audio_path: Audio file to diarize.
        voice_profiles: Optional store of known speaker embeddings.
        min_speakers: Lower bound on speakers, passed to the pipeline.
        max_speakers: Upper bound on speakers, passed to the pipeline.
        host_match_threshold: Minimum cosine similarity for any profile match
            (hosts, guests and callers alike).

    Returns:
        A DiarizationResult whose turns carry the friendly names.
        ``num_speakers`` counts the raw labels the pipeline detected.

    Raises:
        RuntimeError: If the diarization pipeline could not be loaded.
    """
    from pyannote.audio import Pipeline
    import torch

    audio_path = Path(audio_path)
    console.print(f"[bold]Diarizing:[/bold] {audio_path.name}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    console.print(f"[dim]Device: {device}[/dim]")

    pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
    if pipeline is None:
        # pyannote may return None (instead of raising) when the checkpoint
        # cannot be downloaded, e.g. a gated model without accepted terms.
        raise RuntimeError(
            "Could not load 'pyannote/speaker-diarization-3.1'. Accept the model's "
            "terms on Hugging Face and authenticate (e.g. `huggingface-cli login`)."
        )
    pipeline = pipeline.to(device)

    diarization = pipeline(
        str(audio_path),
        min_speakers=min_speakers,
        max_speakers=max_speakers,
    )

    # Extract turns, and remember each raw label's longest segment (first one
    # wins on ties) as its clip for voice-profile matching.
    raw_turns: list[SpeakerTurn] = []
    longest_segments: dict = {}
    for segment, _, speaker in diarization.itertracks(yield_label=True):
        raw_turns.append(SpeakerTurn(
            speaker=speaker,
            start=segment.start,
            end=segment.end,
        ))
        current = longest_segments.get(speaker)
        if current is None or segment.duration > current.duration:
            longest_segments[speaker] = segment

    # Unique raw labels in first-appearance order.
    raw_speakers = list(longest_segments)
    console.print(f"[dim]Detected {len(raw_speakers)} speakers[/dim]")

    ranked = _rank_by_speaking_time(raw_turns)
    speaker_map: dict[str, str] = {}

    if voice_profiles and voice_profiles.embeddings:
        console.print("[dim]Matching speakers against voice profiles...[/dim]")
        speaker_map = _match_voice_profiles(
            pipeline, audio_path, longest_segments, dict(ranked),
            voice_profiles, host_match_threshold, device,
        )

        # If no voice profiles matched, use speaking time heuristic:
        # the host almost always has the most speaking time.
        if not speaker_map and ranked:
            host_label, host_time = ranked[0]
            speaker_map[host_label] = f"Host: {_host_profile_name(voice_profiles)}"
            console.print(f" [yellow]Assumed {host_label} is host "
                          f"(most speaking time: {host_time:.0f}s)[/yellow]")

    # If no voice profiles at all, label by speaking time
    if not speaker_map:
        for rank, (speaker, _) in enumerate(ranked):
            speaker_map[speaker] = "Host (assumed)" if rank == 0 else f"Speaker {rank}"

    # Apply friendly names; unmapped speakers keep their raw label.
    for turn in raw_turns:
        if turn.speaker in speaker_map:
            turn.speaker = speaker_map[turn.speaker]

    console.print(f"[green]Diarization complete: {len(raw_turns)} turns, "
                  f"{len(raw_speakers)} speakers[/green]")

    return DiarizationResult(
        turns=raw_turns,
        num_speakers=len(raw_speakers),
        speaker_map=speaker_map,
    )
|