# src/multivoice/lib/stt_transcribe_audio.py
import faster_whisper
import logging
import torch


def find_numeral_symbol_tokens(tokenizer):
    """
    Identifies tokens in the tokenizer's vocabulary that contain numerals or
    currency/percent symbols (0-9, %, $, £).

    Args:
        tokenizer: The tokenizer object containing the model's vocabulary.

    Returns:
        A list of token IDs corresponding to tokens that include numeral symbols.
    """
    numeral_symbol_tokens = [
        -1,  # -1 keeps faster-whisper's default suppression list active
    ]
    for token, token_id in tokenizer.get_vocab().items():
        has_numeral_symbol = any(c in "0123456789%$£" for c in token)
        if has_numeral_symbol:
            numeral_symbol_tokens.append(token_id)
    return numeral_symbol_tokens
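

# Example (a hedged sketch, not part of the module): building the suppression
# list directly from a loaded model's tokenizer. The model name "base" is an
# arbitrary placeholder.
#
#   model = faster_whisper.WhisperModel("base", device="cpu", compute_type="int8")
#   tokens_to_suppress = find_numeral_symbol_tokens(model.hf_tokenizer)
#   assert -1 in tokens_to_suppress  # default suppression sentinel is retained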


def transcribe_audio(
    vocal_target, language, suppress_numerals, batch_size, device, mtypes, args
):
    """
    Transcribes the audio file using a Whisper model.

    Args:
        vocal_target: The path to the audio file or the preprocessed audio data.
        language: The target language for transcription.
        suppress_numerals: Boolean flag indicating whether numeral symbols should
            be suppressed in the transcription.
        batch_size: The number of segments to process in one batch. If 0, the
            audio is transcribed without batching.
        device: The device (CPU or GPU) on which to run the model.
        mtypes: A dictionary mapping devices to their respective compute types.
        args: Namespace object containing additional arguments and configurations.

    Returns:
        full_transcript: The complete transcribed text as a string.
        info: Additional information about the transcription process, such as
            language detection details.
    """
    logging.debug("Transcribing the audio file...")
    whisper_model = faster_whisper.WhisperModel(
        args.model_name, device=device, compute_type=mtypes[device]
    )
    whisper_pipeline = faster_whisper.BatchedInferencePipeline(whisper_model)
    audio_waveform = faster_whisper.decode_audio(vocal_target)

    # Build the suppression list from the documented parameter rather than
    # re-reading the flag from `args`.
    suppress_tokens = (
        find_numeral_symbol_tokens(whisper_model.hf_tokenizer)
        if suppress_numerals
        else [-1]
    )

    if batch_size > 0:
        transcript_segments, info = whisper_pipeline.transcribe(
            audio_waveform,
            language=language,
            suppress_tokens=suppress_tokens,
            batch_size=batch_size,
        )
    else:
        transcript_segments, info = whisper_model.transcribe(
            audio_waveform,
            language=language,
            suppress_tokens=suppress_tokens,
            vad_filter=True,
        )
    full_transcript = "".join(segment.text for segment in transcript_segments)

    # Clear GPU VRAM held by the model before returning.
    del whisper_model, whisper_pipeline
    torch.cuda.empty_cache()
    return full_transcript, info
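

# Example usage (a hedged sketch; file names and values below are placeholders,
# not part of the module). `args` only needs the attributes read above, i.e.
# `model_name`.
#
#   from types import SimpleNamespace
#
#   transcript, info = transcribe_audio(
#       vocal_target="vocals.wav",
#       language="en",
#       suppress_numerals=False,
#       batch_size=8,
#       device="cuda" if torch.cuda.is_available() else "cpu",
#       mtypes={"cuda": "float16", "cpu": "int8"},
#       args=SimpleNamespace(model_name="medium.en"),
#   )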