From 403cbd4d15f711edfa84ce7f3b5ab26aa0bb8b64 Mon Sep 17 00:00:00 2001 From: vashdao Date: Mon, 20 Apr 2026 00:49:32 +1000 Subject: [PATCH] fix: bypass whisperx with faster-whisper direct for CUDA 13 / RTX 5000 series compatibility --- backend/services/transcription.py | 171 +++++++----------------------- 1 file changed, 37 insertions(+), 134 deletions(-) diff --git a/backend/services/transcription.py b/backend/services/transcription.py index 38b9f34..529af16 100644 --- a/backend/services/transcription.py +++ b/backend/services/transcription.py @@ -1,59 +1,30 @@ """ -WhisperX-based transcription service with word-level alignment. -Falls back to standard Whisper if WhisperX is not available. +faster-whisper based transcription service with word-level timestamps. +Replaces whisperX to avoid pyannote/speechbrain/k2 dependency issues. """ import logging +import os from pathlib import Path from typing import Optional -import torch +from faster_whisper import WhisperModel -from utils.gpu_utils import get_optimal_device, configure_gpu -from utils.audio_processing import extract_audio from utils.cache import load_from_cache, save_to_cache +from utils.audio_processing import extract_audio logger = logging.getLogger(__name__) _model_cache: dict = {} -try: - import whisperx - WHISPERX_AVAILABLE = True -except ImportError: - WHISPERX_AVAILABLE = False - import whisper - -try: - HF_TOKEN = None - import os - HF_TOKEN = os.environ.get("HF_TOKEN") -except Exception: - pass - - -def _get_device(use_gpu: bool = True) -> torch.device: - if use_gpu: - return get_optimal_device() - return torch.device("cpu") - -def _load_model(model_name: str, device: torch.device): +def _load_model(model_name: str, device: str = "cpu", compute_type: str = "int8"): cache_key = f"{model_name}_{device}" if cache_key in _model_cache: return _model_cache[cache_key] logger.info(f"Loading model: {model_name} on {device}") - if WHISPERX_AVAILABLE: - compute_type = "float16" if device.type == "cuda" else "int8" - model = whisperx.load_model( - model_name, - device=str(device), - compute_type=compute_type, - ) - else: - model = whisper.load_model(model_name, device=device) - + model = WhisperModel(model_name, device=device, compute_type=compute_type) _model_cache[cache_key] = model return model @@ -74,132 +45,64 @@ def transcribe_audio( file_path = Path(file_path) if use_cache: - cached = load_from_cache(file_path, model_name, "transcribe_wx") + cached = load_from_cache(file_path, model_name, "transcribe_fw") if cached: logger.info("Using cached transcription") return cached + # Extract audio if video file video_extensions = {".mp4", ".avi", ".mov", ".mkv", ".webm"} if file_path.suffix.lower() in video_extensions: audio_path = extract_audio(file_path) else: audio_path = file_path - device = _get_device(use_gpu) - model = _load_model(model_name, device) - - logger.info(f"Transcribing: {file_path}") - - if WHISPERX_AVAILABLE: - result = _transcribe_whisperx(model, str(audio_path), device, language) - else: - result = _transcribe_standard(model, str(audio_path), language) + # Use CPU for now (change to "cuda" and "float16" once CUDA is sorted) + device = "cuda" + compute_type = "float16" - if use_cache: - save_to_cache(file_path, result, model_name, "transcribe_wx") - - return result + model = _load_model(model_name, device, compute_type) + logger.info(f"Transcribing: {file_path}") -def _transcribe_whisperx(model, audio_path: str, device: torch.device, language: Optional[str]) -> dict: - audio = whisperx.load_audio(audio_path) - transcribe_opts = {} + transcribe_opts = {"word_timestamps": True} if language: transcribe_opts["language"] = language - result = model.transcribe(audio, batch_size=16, **transcribe_opts) - detected_language = result.get("language", "en") - - align_model, align_metadata = whisperx.load_align_model( - language_code=detected_language, - device=str(device), - ) - aligned = whisperx.align( - result["segments"], - align_model, - align_metadata, - audio, - str(device), - return_char_alignments=False, - ) + segments_gen, info = model.transcribe(str(audio_path), **transcribe_opts) + detected_language = info.language words = [] - for seg in aligned.get("segments", []): - for w in seg.get("words", []): - words.append({ - "word": w.get("word", ""), - "start": round(w.get("start", 0), 3), - "end": round(w.get("end", 0), 3), - "confidence": round(w.get("score", 0), 3), - }) - segments = [] - for i, seg in enumerate(aligned.get("segments", [])): - seg_words = [] - for w in seg.get("words", []): - seg_words.append({ - "word": w.get("word", ""), - "start": round(w.get("start", 0), 3), - "end": round(w.get("end", 0), 3), - "confidence": round(w.get("score", 0), 3), - }) - segments.append({ - "id": i, - "start": round(seg.get("start", 0), 3), - "end": round(seg.get("end", 0), 3), - "text": seg.get("text", "").strip(), - "words": seg_words, - }) - - return { - "words": words, - "segments": segments, - "language": detected_language, - } - - -def _transcribe_standard(model, audio_path: str, language: Optional[str]) -> dict: - """Fallback: standard Whisper (segment-level only, synthesized word timestamps).""" - opts = {} - if language: - opts["language"] = language - - result = model.transcribe(audio_path, **opts) - detected_language = result.get("language", "en") - - words = [] - segments = [] - - for i, seg in enumerate(result.get("segments", [])): - text = seg.get("text", "").strip() - seg_start = seg.get("start", 0) - seg_end = seg.get("end", 0) - seg_words_text = text.split() - duration = seg_end - seg_start + for i, seg in enumerate(segments_gen): seg_words = [] - for j, w_text in enumerate(seg_words_text): - w_start = seg_start + (j / max(len(seg_words_text), 1)) * duration - w_end = seg_start + ((j + 1) / max(len(seg_words_text), 1)) * duration - word_obj = { - "word": w_text, - "start": round(w_start, 3), - "end": round(w_end, 3), - "confidence": 0.5, - } - words.append(word_obj) - seg_words.append(word_obj) + if seg.words: + for w in seg.words: + word_obj = { + "word": w.word.strip(), + "start": round(w.start, 3), + "end": round(w.end, 3), + "confidence": round(w.probability, 3), + } + words.append(word_obj) + seg_words.append(word_obj) segments.append({ "id": i, - "start": round(seg_start, 3), - "end": round(seg_end, 3), - "text": text, + "start": round(seg.start, 3), + "end": round(seg.end, 3), + "text": seg.text.strip(), "words": seg_words, }) - return { + result = { "words": words, "segments": segments, "language": detected_language, } + + if use_cache: + save_to_cache(file_path, result, model_name, "transcribe_fw") + + return result