Source code for multivoice.lib.stt_transcribe_audio

# src/multivoice/lib/stt_transcribe_audio.py

import faster_whisper
import logging
import torch


def find_numeral_symbol_tokens(tokenizer):
    """
    Identifies tokens in the tokenizer's vocabulary that contain numeral symbols.

    Args:
        tokenizer: The tokenizer object containing the model's vocabulary.

    Returns:
        A list of token IDs corresponding to tokens that include numeral symbols.
    """
    # -1 tells faster-whisper to keep its default set of suppressed tokens;
    # the numeral-bearing token IDs are appended on top of that default.
    numeral_symbol_tokens = [
        -1,
    ]
    for token, token_id in tokenizer.get_vocab().items():
        has_numeral_symbol = any(c in "0123456789%$£" for c in token)
        if has_numeral_symbol:
            numeral_symbol_tokens.append(token_id)
    return numeral_symbol_tokens
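
For reference, the helper can be exercised on its own. The sketch below is
illustrative and not part of the module; the "base" model name and int8
compute type are assumptions.

# Usage sketch: flag numeral-bearing vocabulary tokens of a small model.
import faster_whisper

model = faster_whisper.WhisperModel("base", device="cpu", compute_type="int8")
token_ids = find_numeral_symbol_tokens(model.hf_tokenizer)
print(f"{len(token_ids)} token IDs flagged (the leading -1 keeps the defaults)")
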
def transcribe_audio(
    vocal_target, language, suppress_numerals, batch_size, device, mtypes, args
):
    """
    Transcribes the audio file using a Whisper model.

    Args:
        vocal_target: The path to the audio file or the preprocessed audio data.
        language: The target language for transcription.
        suppress_numerals: Boolean flag to indicate whether numeral symbols
            should be suppressed in the transcription.
        batch_size: The number of segments to process in one batch. If 0,
            processes without batching.
        device: The device (CPU or GPU) on which to run the model.
        mtypes: A dictionary mapping devices to their respective compute types.
        args: Namespace object containing additional arguments and configurations.

    Returns:
        full_transcript: The complete transcribed text as a string.
        info: Additional information about the transcription process, such as
            language detection details.
    """
    logging.debug("Transcribing the audio file...")

    whisper_model = faster_whisper.WhisperModel(
        args.model_name, device=device, compute_type=mtypes[device]
    )
    whisper_pipeline = faster_whisper.BatchedInferencePipeline(whisper_model)
    audio_waveform = faster_whisper.decode_audio(vocal_target)

    # When numeral suppression is requested, suppress every numeral-bearing
    # token; [-1] alone keeps only faster-whisper's default suppression set.
    suppress_tokens = (
        find_numeral_symbol_tokens(whisper_model.hf_tokenizer)
        if suppress_numerals
        else [-1]
    )

    if batch_size > 0:
        transcript_segments, info = whisper_pipeline.transcribe(
            audio_waveform,
            language,
            suppress_tokens=suppress_tokens,
            batch_size=batch_size,
        )
    else:
        transcript_segments, info = whisper_model.transcribe(
            audio_waveform,
            language,
            suppress_tokens=suppress_tokens,
            vad_filter=True,
        )

    full_transcript = "".join(segment.text for segment in transcript_segments)

    # Release the model and free GPU VRAM before returning.
    del whisper_model, whisper_pipeline
    torch.cuda.empty_cache()

    return full_transcript, info
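
A hypothetical caller might look like the following. The model name, audio
path, batch size, and the argparse-style namespace are assumptions for
illustration, not values defined by multivoice.

# Usage sketch: transcribe a file on CPU with batched inference.
from types import SimpleNamespace

args = SimpleNamespace(model_name="large-v2")  # assumed model name
mtypes = {"cpu": "int8", "cuda": "float16"}    # assumed compute types

transcript, info = transcribe_audio(
    vocal_target="audio.wav",  # assumed example path
    language=None,             # None lets the model auto-detect the language
    suppress_numerals=False,
    batch_size=8,
    device="cpu",
    mtypes=mtypes,
    args=args,
)
print(info.language, transcript[:80])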