Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 37 additions & 134 deletions backend/services/transcription.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,30 @@
"""
WhisperX-based transcription service with word-level alignment.
Falls back to standard Whisper if WhisperX is not available.
faster-whisper based transcription service with word-level timestamps.
Replaces whisperX to avoid pyannote/speechbrain/k2 dependency issues.
"""

import logging
import os
from pathlib import Path
from typing import Optional

import torch
from faster_whisper import WhisperModel

from utils.gpu_utils import get_optimal_device, configure_gpu
from utils.audio_processing import extract_audio
from utils.cache import load_from_cache, save_to_cache
from utils.audio_processing import extract_audio

logger = logging.getLogger(__name__)

_model_cache: dict = {}

# Optional dependency: use whisperx when it is installed, otherwise fall back
# to the standard whisper package. WHISPERX_AVAILABLE records which import
# succeeded so later code can branch on it.
try:
import whisperx
WHISPERX_AVAILABLE = True
except ImportError:
WHISPERX_AVAILABLE = False
import whisper

# Optional Hugging Face access token read from the environment.
# os.environ.get returns None when HF_TOKEN is unset, so no try/except
# (and no silent exception swallowing) is needed here.
HF_TOKEN = os.environ.get("HF_TOKEN")


def _get_device(use_gpu: bool = True) -> torch.device:
    """Select the torch device to run on.

    Args:
        use_gpu: When true, defer the choice to ``get_optimal_device()``;
            when false, always use the CPU.

    Returns:
        The device returned by ``get_optimal_device()``, or
        ``torch.device("cpu")`` when ``use_gpu`` is false.
    """
    if use_gpu:
        return get_optimal_device()
    return torch.device("cpu")


def _load_model(model_name: str, device: torch.device):
def _load_model(model_name: str, device: str = "cpu", compute_type: str = "int8"):
cache_key = f"{model_name}_{device}"
if cache_key in _model_cache:
return _model_cache[cache_key]

logger.info(f"Loading model: {model_name} on {device}")
if WHISPERX_AVAILABLE:
compute_type = "float16" if device.type == "cuda" else "int8"
model = whisperx.load_model(
model_name,
device=str(device),
compute_type=compute_type,
)
else:
model = whisper.load_model(model_name, device=device)

model = WhisperModel(model_name, device=device, compute_type=compute_type)
_model_cache[cache_key] = model
return model

Expand All @@ -74,132 +45,64 @@ def transcribe_audio(
file_path = Path(file_path)

if use_cache:
cached = load_from_cache(file_path, model_name, "transcribe_wx")
cached = load_from_cache(file_path, model_name, "transcribe_fw")
if cached:
logger.info("Using cached transcription")
return cached

# Extract audio if video file
video_extensions = {".mp4", ".avi", ".mov", ".mkv", ".webm"}
if file_path.suffix.lower() in video_extensions:
audio_path = extract_audio(file_path)
else:
audio_path = file_path

device = _get_device(use_gpu)
model = _load_model(model_name, device)

logger.info(f"Transcribing: {file_path}")

if WHISPERX_AVAILABLE:
result = _transcribe_whisperx(model, str(audio_path), device, language)
else:
result = _transcribe_standard(model, str(audio_path), language)
# Run on CUDA with float16 (change to "cpu" and "int8" if no CUDA device is available)
device = "cuda"
compute_type = "float16"

if use_cache:
save_to_cache(file_path, result, model_name, "transcribe_wx")

return result
model = _load_model(model_name, device, compute_type)

logger.info(f"Transcribing: {file_path}")

def _transcribe_whisperx(model, audio_path: str, device: torch.device, language: Optional[str]) -> dict:
audio = whisperx.load_audio(audio_path)
transcribe_opts = {}
transcribe_opts = {"word_timestamps": True}
if language:
transcribe_opts["language"] = language

result = model.transcribe(audio, batch_size=16, **transcribe_opts)
detected_language = result.get("language", "en")

align_model, align_metadata = whisperx.load_align_model(
language_code=detected_language,
device=str(device),
)
aligned = whisperx.align(
result["segments"],
align_model,
align_metadata,
audio,
str(device),
return_char_alignments=False,
)
segments_gen, info = model.transcribe(str(audio_path), **transcribe_opts)
detected_language = info.language

words = []
for seg in aligned.get("segments", []):
for w in seg.get("words", []):
words.append({
"word": w.get("word", ""),
"start": round(w.get("start", 0), 3),
"end": round(w.get("end", 0), 3),
"confidence": round(w.get("score", 0), 3),
})

segments = []
for i, seg in enumerate(aligned.get("segments", [])):
seg_words = []
for w in seg.get("words", []):
seg_words.append({
"word": w.get("word", ""),
"start": round(w.get("start", 0), 3),
"end": round(w.get("end", 0), 3),
"confidence": round(w.get("score", 0), 3),
})
segments.append({
"id": i,
"start": round(seg.get("start", 0), 3),
"end": round(seg.get("end", 0), 3),
"text": seg.get("text", "").strip(),
"words": seg_words,
})

return {
"words": words,
"segments": segments,
"language": detected_language,
}


def _transcribe_standard(model, audio_path: str, language: Optional[str]) -> dict:
"""Fallback: standard Whisper (segment-level only, synthesized word timestamps)."""
opts = {}
if language:
opts["language"] = language

result = model.transcribe(audio_path, **opts)
detected_language = result.get("language", "en")

words = []
segments = []

for i, seg in enumerate(result.get("segments", [])):
text = seg.get("text", "").strip()
seg_start = seg.get("start", 0)
seg_end = seg.get("end", 0)
seg_words_text = text.split()
duration = seg_end - seg_start

for i, seg in enumerate(segments_gen):
seg_words = []
for j, w_text in enumerate(seg_words_text):
w_start = seg_start + (j / max(len(seg_words_text), 1)) * duration
w_end = seg_start + ((j + 1) / max(len(seg_words_text), 1)) * duration
word_obj = {
"word": w_text,
"start": round(w_start, 3),
"end": round(w_end, 3),
"confidence": 0.5,
}
words.append(word_obj)
seg_words.append(word_obj)
if seg.words:
for w in seg.words:
word_obj = {
"word": w.word.strip(),
"start": round(w.start, 3),
"end": round(w.end, 3),
"confidence": round(w.probability, 3),
}
words.append(word_obj)
seg_words.append(word_obj)

segments.append({
"id": i,
"start": round(seg_start, 3),
"end": round(seg_end, 3),
"text": text,
"start": round(seg.start, 3),
"end": round(seg.end, 3),
"text": seg.text.strip(),
"words": seg_words,
})

return {
result = {
"words": words,
"segments": segments,
"language": detected_language,
}

if use_cache:
save_to_cache(file_path, result, model_name, "transcribe_fw")

return result