diff --git a/backend/routers/transcribe.py b/backend/routers/transcribe.py
index c03c209..9116031 100644
--- a/backend/routers/transcribe.py
+++ b/backend/routers/transcribe.py
@@ -1,9 +1,12 @@
 """Transcription endpoint using WhisperX."""
 
+import asyncio
+import json
 import logging
 from typing import Optional
 
 from fastapi import APIRouter, HTTPException
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 
 from services.transcription import transcribe_audio
@@ -51,3 +54,70 @@ async def transcribe(req: TranscribeRequest):
     except Exception as e:
         logger.error(f"Transcription failed: {e}", exc_info=True)
         raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.post("/transcribe/stream")
+async def transcribe_stream(req: TranscribeRequest):
+    """Stream transcription progress, then the result, as Server-Sent Events.
+
+    Each event is one ``data: <json>`` frame. Progress frames carry
+    ``progress`` and ``label``; the last frame carries either ``result``
+    (together with ``progress: 100``) or ``error``.
+    """
+    loop = asyncio.get_running_loop()
+    # Unbounded on purpose: every producer uses put_nowait(), which raises
+    # QueueFull on a bounded queue. Losing the final result or the None
+    # sentinel that way would leave generate() waiting forever.
+    queue: asyncio.Queue = asyncio.Queue()
+
+    def progress_cb(pct: int, label: str):
+        # Runs on the executor thread, so hand the event to the loop thread.
+        loop.call_soon_threadsafe(queue.put_nowait, {"progress": pct, "label": label})
+
+    async def run():
+        try:
+            result = await loop.run_in_executor(None, lambda: transcribe_audio(
+                file_path=req.file_path,
+                model_name=req.model,
+                use_gpu=req.use_gpu,
+                use_cache=req.use_cache,
+                language=req.language,
+                progress_cb=progress_cb,
+            ))
+            if req.diarize and req.hf_token:
+                result = await loop.run_in_executor(None, lambda: diarize_and_label(
+                    transcription_result=result,
+                    audio_path=req.file_path,
+                    hf_token=req.hf_token,
+                    num_speakers=req.num_speakers,
+                    use_gpu=req.use_gpu,
+                ))
+            queue.put_nowait({"progress": 100, "label": "Done", "result": result})
+        except Exception as e:
+            logger.error(f"Transcription failed: {e}", exc_info=True)
+            queue.put_nowait({"error": str(e)})
+        finally:
+            # Sentinel: tells generate() that no more events will follow.
+            queue.put_nowait(None)
+
+    task = asyncio.create_task(run())
+
+    async def generate():
+        try:
+            while True:
+                item = await queue.get()
+                if item is None:
+                    break
+                yield f"data: {json.dumps(item)}\n\n"
+        finally:
+            # Stop waiting on the job if the client disconnects early. This only
+            # cancels the awaiting task; a worker thread that is already running
+            # cannot be interrupted and finishes in the background.
+            task.cancel()
+
+    return StreamingResponse(
+        generate(),
+        media_type="text/event-stream",
+        # Keep intermediaries from caching or buffering the progress events.
+        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
+    )
diff --git a/backend/services/transcription.py b/backend/services/transcription.py
index 38b9f34..9e73bb7 100644
--- a/backend/services/transcription.py
+++ b/backend/services/transcription.py
@@ -64,6 +64,7 @@ def transcribe_audio(
     use_gpu: bool = True,
     use_cache: bool = True,
     language: Optional[str] = None,
+    progress_cb=None,
 ) -> dict:
     """
     Transcribe audio/video file and return word-level timestamps.
@@ -73,14 +74,19 @@ def transcribe_audio(
     """
     file_path = Path(file_path)
 
+    progress = progress_cb or (lambda pct, label: None)
+
     if use_cache:
         cached = load_from_cache(file_path, model_name, "transcribe_wx")
         if cached:
             logger.info("Using cached transcription")
+            progress(100, "Done (cached)")
             return cached
 
+    progress(5, "Loading model…")
     video_extensions = {".mp4", ".avi", ".mov", ".mkv", ".webm"}
     if file_path.suffix.lower() in video_extensions:
+        progress(10, "Extracting audio…")
         audio_path = extract_audio(file_path)
     else:
         audio_path = file_path
@@ -91,7 +97,8 @@ def transcribe_audio(
     logger.info(f"Transcribing: {file_path}")
 
     if WHISPERX_AVAILABLE:
-        result = _transcribe_whisperx(model, str(audio_path), device, language)
+        result = _transcribe_whisperx(model, str(audio_path), device, language, progress)
     else:
+        progress(30, "Transcribing…")
         result = _transcribe_standard(model, str(audio_path), language)
 
@@ -101,15 +108,20 @@ def transcribe_audio(
     return result
 
 
-def _transcribe_whisperx(model, audio_path: str, device: torch.device, language: Optional[str]) -> dict:
+def _transcribe_whisperx(model, audio_path: str, device: torch.device, language: Optional[str], progress=None) -> dict:
+    # Optional so callers that pass only the first four arguments still work.
+    progress = progress or (lambda pct, label: None)
+    progress(20, "Detecting speech…")
     audio = whisperx.load_audio(audio_path)
     transcribe_opts = {}
     if language:
         transcribe_opts["language"] = language
 
+    progress(30, "Transcribing…")
     result = model.transcribe(audio, batch_size=16, **transcribe_opts)
     detected_language = result.get("language", "en")
 
+    progress(70, "Aligning words…")
     align_model, align_metadata = whisperx.load_align_model(
         language_code=detected_language,
         device=str(device),
@@ -122,6 +134,7 @@ def _transcribe_whisperx(model, audio_path: str, device: torch.device, language:
         str(device),
         return_char_alignments=False,
     )
+    progress(90, "Finalizing…")
 
     words = []
     for seg in aligned.get("segments", []):
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx
index 8542392..1452fea 100644
--- a/frontend/src/App.tsx
+++ b/frontend/src/App.tsx
@@ -38,6 +38,7 @@ export default function App() {
   const [activePanel, setActivePanel] = useState(null);
   const [manualPath, setManualPath] = useState('');
   const [whisperModel, setWhisperModel] = useState('base');
+  const [transcriptionLabel, setTranscriptionLabel] = useState('');
   const fileInputRef = useRef(null);
 
   useKeyboardShortcuts();
@@ -89,20 +90,55 @@ export default function App() {
 
   const transcribeVideo = async (path: string) => {
     setTranscribing(true, 0);
+    setTranscriptionLabel('Starting…');
     try {
-      const res = await fetch(`${backendUrl}/transcribe`, {
+      const res = await fetch(`${backendUrl}/transcribe/stream`, {
         method: 'POST',
         headers: { 'Content-Type': 'application/json' },
         body: JSON.stringify({ file_path: path, model: whisperModel }),
       });
       if (!res.ok) throw new Error(`Transcription failed: ${res.statusText}`);
-      const data = await res.json();
-      setTranscription(data);
+      if (!res.body) throw new Error('Transcription failed: response has no body to stream.');
+
+      const reader = res.body.getReader();
+      const decoder = new TextDecoder();
+      let buffer = '';
+      let gotResult = false;
+
+      try {
+        while (true) {
+          const { done, value } = await reader.read();
+          if (done) break;
+          buffer += decoder.decode(value, { stream: true });
+          // Keep the trailing partial line in the buffer until its newline arrives.
+          const lines = buffer.split('\n');
+          buffer = lines.pop()!;
+          for (const line of lines) {
+            if (!line.startsWith('data: ')) continue;
+            const event = JSON.parse(line.slice(6));
+            if (event.error) throw new Error(event.error);
+            setTranscribing(true, event.progress);
+            setTranscriptionLabel(event.label ?? '');
+            if (event.result) {
+              setTranscription(event.result);
+              gotResult = true;
+            }
+          }
+        }
+      } finally {
+        // cancel() rejects if the stream already errored; the read() failure is
+        // the error worth reporting, so swallow this one.
+        reader.cancel().catch(() => {});
+      }
+      // A stream that closes without a result (e.g. dropped connection) must not
+      // look like a successful run.
+      if (!gotResult) throw new Error('Transcription stream ended before a result was received.');
     } catch (err) {
       console.error('Transcription error:', err);
       alert(`Transcription failed. Check the console for details.\n\n${err}`);
     } finally {
       setTranscribing(false);
+      setTranscriptionLabel('');
     }
   };
 
@@ -247,7 +283,7 @@ export default function App() {

- Transcribing... {Math.round(transcriptionProgress)}% + {transcriptionLabel || 'Transcribing…'} {Math.round(transcriptionProgress)}%

) : words.length > 0 ? (