Files
claudetools/projects/radio-show/audio-processor/diarize_training.py
Mike Swanson 79abef9dc9 radio: diarization pipeline fixes, benchmark setup, test episode set
- Fix voice_profiler threshold bug (HOST label overwrote Unknown unconditionally)
- Audio preload optimization: single ffmpeg per episode, 149.5x realtime on 5070 Ti
- WavLM threshold raised to 0.85 (Mike 0.90-0.99, callers 0.46-0.83)
- Promo/bumper filter: weighted signature scoring, 42->27 clean Q&A pairs
- Text-only Q&A fallback for episodes with no CALLER diarization labels
- TRANSFORMERS_OFFLINE=1 to skip HuggingFace freshness checks
- Add diarize_2018.py for targeted re-run + FTS5 rebuild
- Add benchmark.py + BENCH_SETUP.md for GURU-BEAST-ROG (RTX 4090) comparison
- Commit 9-episode training diarization.json outputs
- Session log: 2026-04-27-diarization-pipeline.md

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-27 13:20:40 -07:00

149 lines
5.4 KiB
Python

"""
Diarize every training episode, writing diarization.json beside its transcript,
then rebuild the archive DB from scratch with HOST/CALLER speaker labels.
"""
import os
import sys

# Environment tweaks that must run before the heavy imports further down.
#
# PYTHONIOENCODING is only consulted when an interpreter starts, so setting it
# here influences child Python processes. This process's already-open stdout
# and stderr are switched to UTF-8 in place via reconfigure(), so Rich's
# Braille spinner characters don't crash output on Windows consoles.
#
# TRANSFORMERS_OFFLINE=1 stops transformers from checking HuggingFace for model
# updates on every from_pretrained() call; the models are already cached
# locally.
os.environ.update(PYTHONIOENCODING="utf-8", TRANSFORMERS_OFFLINE="1")
for _stream in (sys.stdout, sys.stderr):
    if hasattr(_stream, "reconfigure"):
        _stream.reconfigure(encoding="utf-8")
del _stream
from pathlib import Path

# Ensure CUDA libs before any torch imports
from src.gpu import ensure_cuda_libs

ensure_cuda_libs()

from src.config import load_config
from src.diarizer import diarize, VoiceProfileStore
from src.indexer import ArchiveIndex
from src.qa_extractor import (
    load_diarized_transcript,
    extract_qa_pairs,
    tag_qa_pairs_with_ollama,
)
from rich.console import Console

console = Console()

# All data paths are resolved relative to this script, not the working dir.
BASE = Path(__file__).parent
EPISODES_DIR = BASE / "training-data" / "episodes"
TRANSCRIPTS_DIR = BASE / "training-data" / "transcripts"
DB_PATH = BASE / "archive" / "archive.db"

config = load_config()

# Voice profiles handed to diarize() for host matching.
voice_profiles = VoiceProfileStore(
    config.resolve_path(config.diarization.voice_profiles_dir)
)

episodes = sorted(EPISODES_DIR.glob("*.mp3"))
if not episodes:
    # Step 2 deletes the archive DB and rebuilds it from this list; with no
    # episodes that would silently replace the archive with an empty one.
    # markup=False: the path is printed verbatim, never parsed as Rich markup.
    console.print(
        f"No .mp3 episodes found in {EPISODES_DIR} — aborting",
        style="red",
        markup=False,
    )
    raise SystemExit(1)
console.print(f"[bold]Diarizing {len(episodes)} training episodes[/bold]")
# ── Step 1: Diarize ───────────────────────────────────────────────────────────
# Diarize each episode that has a transcript directory but no diarization.json
# yet. Existing outputs are left alone, so a re-run only processes new or
# previously failed episodes. A failure is reported (with traceback) and the
# loop moves on; Step 2 then calls load_diarized_transcript() without a
# diarization file for that episode.
import traceback

from rich.markup import escape

# Host voice-match threshold passed to diarize(). 0.85 was chosen to sit
# between the WavLM similarity ranges observed when tuning (host ~0.90-0.99,
# callers ~0.46-0.83).
HOST_MATCH_THRESHOLD = 0.85

for i, ep_path in enumerate(episodes, 1):
    stem = ep_path.stem
    transcript_dir = TRANSCRIPTS_DIR / stem
    if not transcript_dir.exists():
        console.print(f"[{i}/{len(episodes)}] [yellow]No transcript dir: {escape(stem)} — skipping[/yellow]")
        continue
    diarization_out = transcript_dir / "diarization.json"
    if diarization_out.exists():
        console.print(f"[{i}/{len(episodes)}] [dim]Already diarized: {escape(stem)}[/dim]")
        continue
    console.print(f"\n[{i}/{len(episodes)}] Diarizing: {escape(stem)}")
    try:
        result = diarize(
            ep_path,
            voice_profiles=voice_profiles,
            min_speakers=config.diarization.min_speakers,
            max_speakers=config.diarization.max_speakers,
            host_match_threshold=HOST_MATCH_THRESHOLD,
        )
        result.save(transcript_dir)
        speakers = result.speakers_ranked()
        console.print(f" Done — {len(result.turns)} turns | top speakers: "
                      + ", ".join(f"{s} ({t:.0f}s)" for s, t in speakers[:3]))
    except Exception as e:
        # Exception text is arbitrary: escape it so "[...]" inside the message
        # is printed literally rather than parsed as Rich markup (which would
        # swallow it, or raise MarkupError from inside this handler).
        console.print(f" [red]FAILED: {escape(str(e))}[/red]")
        traceback.print_exc()
# ── Step 2: Rebuild DB ────────────────────────────────────────────────────────
# The archive DB is rebuilt from scratch on every run: delete the old file and
# let ArchiveIndex create a fresh one.
console.print("\n[bold]Rebuilding archive DB with diarization...[/bold]")
if DB_PATH.exists():
    DB_PATH.unlink()
    console.print("[dim]Cleared existing DB[/dim]")
# Also remove leftover SQLite sidecar files (rollback journal, WAL, shared
# memory). If only the main file is deleted, a stale journal/WAL can be paired
# with the freshly created database on the next open and corrupt it.
# Assumes ArchiveIndex is SQLite-backed — TODO confirm; for any other backend
# these files simply won't exist and nothing is removed.
for _suffix in ("-journal", "-wal", "-shm"):
    _sidecar = DB_PATH.with_name(DB_PATH.name + _suffix)
    if _sidecar.exists():
        _sidecar.unlink()
import json
import re
_HOUR_SUFFIX_RE = re.compile(r"-hr\d$", re.IGNORECASE)


def episode_id(stem):
    """Return the episode identifier for a file stem.

    A trailing ``-hr<digit>`` hour marker (single digit, any case) is removed,
    so ``"2018-01-06-hr1"`` and ``"2018-01-06-HR2"`` both give
    ``"2018-01-06"``. Stems without that suffix are returned unchanged.
    """
    return _HOUR_SUFFIX_RE.sub("", stem)
# Index every episode's (diarized) transcript, then extract, tag and index its
# Q&A pairs. Episodes without a transcript.json are skipped.
from rich.markup import escape

with ArchiveIndex(DB_PATH) as idx:
    for ep_path in episodes:
        stem = ep_path.stem
        transcript_dir = TRANSCRIPTS_DIR / stem
        transcript_path = transcript_dir / "transcript.json"
        diarization_path = transcript_dir / "diarization.json"
        if not transcript_path.exists():
            console.print(f"[yellow]No transcript: {escape(stem)} — skipping[/yellow]")
            continue

        # NOTE(review): episode_id() maps "<x>-hr1" and "<x>-hr2" to the same
        # ep_id, so add_episode()/add_segments() run once per hour file under a
        # shared id, each with that file's own timestamps — confirm
        # ArchiveIndex is meant to merge hour files this way.
        ep_id = episode_id(stem)
        date_m = re.search(r"(\d{4}-\d{2}-\d{2})", stem)
        date = date_m.group(1) if date_m else None

        # JSON is UTF-8. Without an explicit encoding, open() uses the locale's
        # preferred encoding (a legacy code page on Windows) and can raise or
        # mis-decode on non-ASCII transcript text.
        with open(transcript_path, encoding="utf-8") as f:
            td = json.load(f)
        duration = td.get("duration")

        # Without a diarization.json, None is passed (Step 1 skipped/failed).
        segments = load_diarized_transcript(
            transcript_path,
            diarization_path if diarization_path.exists() else None,
        )
        idx.add_episode(ep_id, ep_path, date=date, duration=duration)
        idx.add_segments(ep_id, segments)

        # Speaker breakdown: HOST vs everything labelled CALLER/UNKNOWN.
        host_segs = sum(1 for s in segments if s["speaker"] == "HOST")
        caller_segs = sum(1 for s in segments if s["speaker"] in ("CALLER", "UNKNOWN"))
        console.print(f" {escape(ep_id)}: {len(segments)} segs "
                      f"(HOST={host_segs}, other={caller_segs})")

        # Extract Q&A pairs, tag them with topics via Ollama, then index them.
        pairs = extract_qa_pairs(segments)
        console.print(f" {len(pairs)} Q&A pairs", end="")
        if pairs:
            console.print(" — tagging with Ollama...", end="")
            pairs = tag_qa_pairs_with_ollama(
                pairs, ollama_host=config.llm.ollama_host, model=config.llm.model
            )
            for pair in pairs:
                idx.add_qa_pair(
                    ep_id,
                    pair.question_start, pair.question_end,
                    pair.answer_start, pair.answer_end,
                    pair.question_text, pair.answer_text,
                    topic=pair.topic, tags=pair.topic_tags,
                )
        console.print()  # terminate the line started with end=""
    # Read stats while the index is still open.
    stats = idx.stats()

console.print(f"\n[bold green]Done.[/bold green] "
              f"{stats['episodes']} episodes | "
              f"{stats['segments']} segments | "
              f"{stats['qa_pairs']} Q&A pairs")
console.print(f"DB: {DB_PATH}")