Source code for multivoice.lib.stt_align_transcription

# src/multivoice/lib/stt_align_transcription.py

import logging
import torch
from ctc_forced_aligner.alignment_utils import generate_emissions
from ctc_forced_aligner import (
    get_alignments,
    get_spans,
    load_alignment_model,
    postprocess_results,
    preprocess_text,
)
from multivoice.lib.stt_langs import langs_to_iso


def align_transcription(full_transcript, audio_waveform, device, batch_size, info):
    """
    Perform forced alignment on a given audio waveform using a pre-defined transcript.

    The transcript language is resolved before the alignment model is loaded,
    so an unsupported language fails fast instead of after the model load and
    emission pass. An empty or whitespace-only transcript short-circuits to an
    empty result without loading the model.

    Args:
        full_transcript (str): The transcription of the audio.
        audio_waveform (np.ndarray): The audio waveform as a NumPy array. It is
            converted with ``torch.from_numpy``. The sample rate must match what
            the alignment model expects (presumably 16 kHz; TODO confirm).
        device (str): The device to run the model on, e.g. "cpu", "cuda" or
            "cuda:1". Any CUDA device runs the model in float16; every other
            device uses float32.
        batch_size (int): The batch size used when generating emissions.
        info (object): An object exposing a ``language`` attribute whose value
            is a key of ``langs_to_iso``.

    Returns:
        list: The word timestamps produced by ``postprocess_results`` for the
        aligned transcription, or an empty list if the transcript has no text.

    Raises:
        KeyError: If ``info.language`` has no entry in ``langs_to_iso``.
    """
    # Nothing to align: skip the model load and the alignment pass entirely.
    if not full_transcript or not full_transcript.strip():
        logging.debug("Empty transcript; skipping forced alignment.")
        return []

    # Resolve the ISO code up front so an unsupported language fails before
    # the alignment model is loaded onto the device.
    try:
        iso_language = langs_to_iso[info.language]
    except KeyError as err:
        raise KeyError(
            f"No ISO language mapping for {info.language!r}; "
            "cannot run forced alignment."
        ) from err

    logging.debug("Performing forced alignment...")

    # Half precision is only used on GPU. str() also covers torch.device
    # objects and indexed devices such as "cuda:0".
    use_cuda = str(device).startswith("cuda")

    # Load the alignment model and tokenizer for the requested device/dtype.
    alignment_model, alignment_tokenizer = load_alignment_model(
        device,
        dtype=torch.float16 if use_cuda else torch.float32,
    )

    # Move the waveform to the model's device and dtype in a single call.
    waveform = torch.from_numpy(audio_waveform).to(
        device=alignment_model.device, dtype=alignment_model.dtype
    )

    # Generate emissions from the audio waveform using the alignment model.
    emissions, stride = generate_emissions(
        alignment_model,
        waveform,
        batch_size=batch_size,
    )

    # The acoustic model and the input waveform are not used after this point;
    # drop our references and release cached GPU memory early.
    del alignment_model, waveform
    if use_cuda:
        torch.cuda.empty_cache()

    # Preprocess the full transcript to get tokens and text with <star> symbols.
    tokens_starred, text_starred = preprocess_text(
        full_transcript,
        romanize=True,
        language=iso_language,
    )

    # Get alignments for the preprocessed tokens using the emissions and tokenizer.
    segments, scores, blank_token = get_alignments(
        emissions,
        tokens_starred,
        alignment_tokenizer,
    )

    # Generate spans from the aligned segments and blank token.
    spans = get_spans(tokens_starred, segments, blank_token)

    # Postprocess results to obtain word timestamps with their corresponding scores.
    word_timestamps = postprocess_results(text_starred, spans, stride, scores)

    return word_timestamps