# src/multivoice/lib/stt_align_transcription.py
import logging
import torch
from ctc_forced_aligner.alignment_utils import generate_emissions
from ctc_forced_aligner import (
get_alignments,
get_spans,
load_alignment_model,
postprocess_results,
preprocess_text,
)
from multivoice.lib.stt_langs import langs_to_iso
logger = logging.getLogger(__name__)


def align_transcription(full_transcript, audio_waveform, device, batch_size, info):
    """
    Perform forced alignment on a given audio waveform using a pre-defined transcript.

    The alignment itself is delegated to ``ctc_forced_aligner``: emissions are
    generated from the audio, the transcript is preprocessed into tokens, the
    tokens are aligned against the emissions, and the result is postprocessed
    into word timestamps.

    Args:
        full_transcript (str): The transcription of the audio. An empty or
            whitespace-only transcript returns ``[]`` immediately, without
            loading the alignment model.
        audio_waveform (np.ndarray): The audio waveform as a NumPy array. It is
            converted with ``torch.from_numpy``. The sampling rate it must have
            is whatever ``ctc_forced_aligner`` expects (TODO confirm, presumably
            16 kHz).
        device (str | torch.device): Device to run the model on, e.g. ``"cpu"``,
            ``"cuda"`` or ``"cuda:1"``. Any CUDA device loads the model in
            float16; every other device uses float32.
        batch_size (int): The batch size passed to ``generate_emissions``.
        info (object): Object exposing a ``language`` attribute. Its value is
            mapped to an ISO code through ``langs_to_iso``.

    Returns:
        list: The word timestamps returned by ``postprocess_results`` for the
        aligned transcription.

    Raises:
        KeyError: If ``info.language`` has no entry in ``langs_to_iso``. This
            is checked before the alignment model is loaded.
    """
    # Nothing to align: avoid loading the model and running alignment at all.
    if not full_transcript or not full_transcript.strip():
        logger.debug("Empty transcript; skipping forced alignment.")
        return []

    # Resolve the language up front so an unsupported language fails fast,
    # before the model is loaded and emissions are computed.
    language = langs_to_iso[info.language]

    logger.debug("Performing forced alignment...")

    # torch.device() normalises "cuda", "cuda:0", torch.device("cuda", 1), ...
    # so every CUDA device gets half precision, not only the literal "cuda".
    on_cuda = torch.device(device).type == "cuda"

    # Load the alignment model and tokenizer on the requested device.
    alignment_model, alignment_tokenizer = load_alignment_model(
        device,
        dtype=torch.float16 if on_cuda else torch.float32,
    )

    # Generate emissions from the waveform. The input tensor is cast to the
    # model's dtype and moved to the model's device. The returned ``stride`` is
    # needed later by postprocess_results.
    emissions, stride = generate_emissions(
        alignment_model,
        torch.from_numpy(audio_waveform)
        .to(alignment_model.dtype)
        .to(alignment_model.device),
        batch_size=batch_size,
    )

    # The model is not used past this point. Drop our reference and, on CUDA,
    # return cached blocks to the driver. This only frees the weights if
    # load_alignment_model keeps no other reference to the model (TODO confirm).
    del alignment_model
    if on_cuda:
        torch.cuda.empty_cache()

    # Preprocess the transcript into tokens and text with the aligner's special
    # symbols inserted. romanize=True presumably maps non-Latin scripts to
    # Latin characters for the tokenizer; confirm in ctc_forced_aligner.
    tokens_starred, text_starred = preprocess_text(
        full_transcript,
        romanize=True,
        language=language,
    )

    # Align the preprocessed tokens against the emissions.
    segments, scores, blank_token = get_alignments(
        emissions,
        tokens_starred,
        alignment_tokenizer,
    )

    # Build spans from the aligned segments and the blank token.
    spans = get_spans(tokens_starred, segments, blank_token)

    # Turn spans plus stride and scores into word timestamps.
    word_timestamps = postprocess_results(text_starred, spans, stride, scores)
    return word_timestamps