- Fix voice_profiler threshold bug (HOST label overwrote Unknown unconditionally) - Audio preload optimization: single ffmpeg per episode, 149.5x realtime on 5070 Ti - WavLM threshold raised to 0.85 (Mike 0.90-0.99, callers 0.46-0.83) - Promo/bumper filter: weighted signature scoring, 42->27 clean Q&A pairs - Text-only Q&A fallback for episodes with no CALLER diarization labels - TRANSFORMERS_OFFLINE=1 to skip HuggingFace freshness checks - Add diarize_2018.py for targeted re-run + FTS5 rebuild - Add benchmark.py + BENCH_SETUP.md for GURU-BEAST-ROG (RTX 4090) comparison - Commit 9-episode training diarization.json outputs - Session log: 2026-04-27-diarization-pipeline.md Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
149 lines
5.4 KiB
Python
149 lines
5.4 KiB
Python
"""
Diarize all training episodes, saving diarization.json next to each transcript.

Then rebuild the archive DB with proper HOST/CALLER labels.

Steps:
  1. For every ``training-data/episodes/*.mp3`` that has a transcript
     directory and no ``diarization.json`` yet, run speaker diarization and
     save the result next to the transcript. Failures are reported and
     skipped so the batch keeps going.
  2. Delete ``archive/archive.db`` and rebuild it from the transcripts
     (using diarization where it exists), extracting Q&A pairs and tagging
     them via Ollama.
"""

import sys
import os

# Force UTF-8 output on Windows so Rich's Braille spinner characters don't crash.
# NOTE: PYTHONIOENCODING is only consulted at interpreter startup, so setting
# it here does NOT change this process's streams — it only propagates to any
# child processes. The reconfigure() calls below are what actually switch the
# current stdout/stderr to UTF-8.
os.environ["PYTHONIOENCODING"] = "utf-8"
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8")
if hasattr(sys.stderr, "reconfigure"):
    sys.stderr.reconfigure(encoding="utf-8")

# Prevent transformers from checking HuggingFace for model updates on every
# from_pretrained() call — models are already cached locally. Set before the
# project imports below so it is already in effect if they load transformers.
os.environ["TRANSFORMERS_OFFLINE"] = "1"
from pathlib import Path

# Order matters here: ensure_cuda_libs() has to run before any torch import.
# The project modules imported after the call presumably pull in torch, so
# keep this call above them.
from src.gpu import ensure_cuda_libs

ensure_cuda_libs()

from src.config import load_config
from src.diarizer import VoiceProfileStore, diarize
from src.indexer import ArchiveIndex
from src.qa_extractor import (
    extract_qa_pairs,
    load_diarized_transcript,
    tag_qa_pairs_with_ollama,
)
from rich.console import Console

# One shared Rich console for all progress output in this script.
console = Console()
# Filesystem layout, all anchored at the directory containing this script.
BASE = Path(__file__).parent
EPISODES_DIR = BASE.joinpath("training-data", "episodes")
TRANSCRIPTS_DIR = BASE.joinpath("training-data", "transcripts")
DB_PATH = BASE.joinpath("archive", "archive.db")

config = load_config()

# Stored voice profiles; passed to diarize() together with
# host_match_threshold for host matching.
profiles_dir = config.diarization.voice_profiles_dir
voice_profiles = VoiceProfileStore(config.resolve_path(profiles_dir))

# Name-sorted so the "[i/N]" progress numbering is stable between runs.
episodes = list(EPISODES_DIR.glob("*.mp3"))
episodes.sort()
console.print(f"[bold]Diarizing {len(episodes)} training episodes[/bold]")
# ── Step 1: Diarize ───────────────────────────────────────────────────────────
import traceback

# Similarity threshold passed to diarize() as host_match_threshold when
# matching speakers against the stored voice profiles. Per the change notes,
# host scores measured 0.90-0.99 and callers 0.46-0.83, so 0.85 sits between.
HOST_MATCH_THRESHOLD = 0.85

total = len(episodes)
failed = []  # stems whose diarization raised; summarized after the loop

for i, ep_path in enumerate(episodes, 1):
    stem = ep_path.stem
    transcript_dir = TRANSCRIPTS_DIR / stem
    if not transcript_dir.exists():
        console.print(f"[{i}/{total}] [yellow]No transcript dir: {stem} — skipping[/yellow]")
        continue

    # Resumable: an existing diarization.json is never recomputed.
    diarization_out = transcript_dir / "diarization.json"
    if diarization_out.exists():
        console.print(f"[{i}/{total}] [dim]Already diarized: {stem}[/dim]")
        continue

    console.print(f"\n[{i}/{total}] Diarizing: {stem}")
    try:
        result = diarize(
            ep_path,
            voice_profiles=voice_profiles,
            min_speakers=config.diarization.min_speakers,
            max_speakers=config.diarization.max_speakers,
            host_match_threshold=HOST_MATCH_THRESHOLD,
        )
        result.save(transcript_dir)
        speakers = result.speakers_ranked()
        console.print(f" Done — {len(result.turns)} turns | top speakers: "
                      + ", ".join(f"{s} ({t:.0f}s)" for s, t in speakers[:3]))
    except Exception as e:
        # Best effort: report and continue so one bad episode does not abort
        # the batch. Step 2 loads the transcript without diarization when no
        # diarization.json was written for it.
        console.print(f" [red]FAILED: {e}[/red]")
        traceback.print_exc()
        failed.append(stem)

if failed:
    console.print(f"\n[red]Diarization failed for {len(failed)} episode(s):[/red] "
                  + ", ".join(failed))
# ── Step 2: Rebuild DB ────────────────────────────────────────────────────────
import json
import re

console.print("\n[bold]Rebuilding archive DB with diarization...[/bold]")

# Full rebuild: start from an empty database file.
if DB_PATH.exists():
    DB_PATH.unlink()
    console.print("[dim]Cleared existing DB[/dim]")

# Also remove SQLite's side files (rollback journal, WAL, shared-memory index)
# if present. A leftover journal/WAL belongs to the deleted database and could
# otherwise be paired with, and applied to, the new database file.
for _suffix in ("-journal", "-wal", "-shm"):
    _side_file = DB_PATH.with_name(DB_PATH.name + _suffix)
    if _side_file.exists():
        _side_file.unlink()

# Make sure archive/ exists in case ArchiveIndex does not create it itself.
DB_PATH.parent.mkdir(parents=True, exist_ok=True)
def episode_id(stem):
    """Map a file stem to its episode ID.

    A trailing hour suffix of the form ``-hr<digit>`` (exactly one digit, any
    letter case, e.g. ``-hr1`` or ``-HR2``) is removed, so stems that differ
    only by that suffix share one ID. Any other stem is returned unchanged.
    """
    hour_suffix = re.compile(r"-hr\d$", re.IGNORECASE)
    return hour_suffix.sub("", stem)
with ArchiveIndex(DB_PATH) as idx:
    for ep_path in episodes:
        stem = ep_path.stem
        transcript_dir = TRANSCRIPTS_DIR / stem
        transcript_path = transcript_dir / "transcript.json"
        diarization_path = transcript_dir / "diarization.json"

        if not transcript_path.exists():
            console.print(f"[yellow]No transcript: {stem} — skipping[/yellow]")
            continue

        # NOTE(review): episode_id() strips a trailing "-hrN", so several hour
        # files can map to one ep_id and add_episode() is then called more than
        # once for the same ID — confirm ArchiveIndex handles that as intended.
        ep_id = episode_id(stem)
        date_m = re.search(r"(\d{4}-\d{2}-\d{2})", stem)
        date = date_m.group(1) if date_m else None

        # Read explicitly as UTF-8: the platform default encoding (e.g. cp1252
        # on Windows) can fail on or mis-decode non-ASCII transcript text.
        with open(transcript_path, encoding="utf-8") as f:
            td = json.load(f)
        duration = td.get("duration")

        # Pass the diarization file only when it exists (it may be missing if
        # Step 1 skipped or failed this episode); otherwise pass None.
        segments = load_diarized_transcript(
            transcript_path,
            diarization_path if diarization_path.exists() else None,
        )

        idx.add_episode(ep_id, ep_path, date=date, duration=duration)
        idx.add_segments(ep_id, segments)

        # Speaker breakdown ("other" counts only CALLER and UNKNOWN labels).
        host_segs = sum(1 for s in segments if s["speaker"] == "HOST")
        caller_segs = sum(1 for s in segments if s["speaker"] in ("CALLER", "UNKNOWN"))
        console.print(f" {ep_id}: {len(segments)} segs "
                      f"(HOST={host_segs}, other={caller_segs})")

        # Extract Q&A pairs; the progress line is finished by the bare
        # console.print() at the end of this iteration.
        pairs = extract_qa_pairs(segments)
        console.print(f" {len(pairs)} Q&A pairs", end="")

        if pairs:
            console.print(" — tagging with Ollama...", end="")
            pairs = tag_qa_pairs_with_ollama(
                pairs, ollama_host=config.llm.ollama_host, model=config.llm.model
            )

        for pair in pairs:
            idx.add_qa_pair(
                ep_id,
                pair.question_start, pair.question_end,
                pair.answer_start, pair.answer_end,
                pair.question_text, pair.answer_text,
                topic=pair.topic, tags=pair.topic_tags,
            )

        console.print()

    # Query totals while the index is still open.
    stats = idx.stats()

console.print(f"\n[bold green]Done.[/bold green] "
              f"{stats['episodes']} episodes | "
              f"{stats['segments']} segments | "
              f"{stats['qa_pairs']} Q&A pairs")
console.print(f"DB: {DB_PATH}")