Source code for multivoice.lib.stt_punctuation

# src/multivoice/lib/stt_punctuation.py

import logging
import re
from deepmultilingualpunctuation import PunctuationModel
from multivoice.lib.stt_langs import punct_model_langs


def restore_punctuation(wsm, info):
    """
    Restores sentence-ending punctuation to a list of words based on the detected language.

    When ``info.language`` is listed in ``punct_model_langs``, the
    ``kredor/punctuate-all`` model is run over the words. A predicted
    sentence-ending mark (".", "?" or "!") is appended to a word when the word
    does not already end with punctuation, or when the word is an acronym
    such as "U.S.". For acronyms, a doubled trailing period is collapsed to a
    single one, so the acronym's own dot doubles as the sentence end. For
    unsupported languages a warning is logged and the words are left as-is.

    Args:
        wsm (list): A list of dictionaries containing word and speaker
            information; each entry's ``"word"`` value may be modified in place.
        info (object): An object exposing the detected ``language``.

    Returns:
        list: The same ``wsm`` list, with punctuation restored where possible.
    """
    logging.debug("Restoring punctuation...")

    if info.language not in punct_model_langs:
        logging.warning(
            f"Punctuation restoration is not available for {info.language} language."
            " Using the original punctuation."
        )
        return wsm

    # The model is (re)loaded on every call.
    punct_model = PunctuationModel(model="kredor/punctuate-all")

    words_list = [word_dict["word"] for word_dict in wsm]
    # Each prediction is indexed with [1] below to read its punctuation label;
    # presumably (word, label, score) — NOTE(review): confirm against the library.
    labeled_words = punct_model.predict(words_list, chunk_size=230)

    ending_puncts = ".?!"
    model_puncts = ".,;:!?"

    def is_acronym(x):
        """
        Checks if a given string is a dotted acronym (e.g. "U.S.").

        Args:
            x (str): The word to check.

        Returns:
            re.Match | None: A truthy match object if the word is an acronym,
            None otherwise.
        """
        return re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x)

    # zip() stops at the shorter sequence, so any words without a prediction
    # are left unchanged.
    for word_dict, labeled_tuple in zip(wsm, labeled_words):
        word = word_dict["word"]
        label = labeled_tuple[1]
        if (
            word
            and label in ending_puncts
            and (word[-1] not in model_puncts or is_acronym(word))
        ):
            word += label
            # Only acronyms can reach here already ending in ".", so appending
            # "." yields "..". Collapse it to ONE period: stripping every dot
            # would turn "U.S." into "U.S" and lose the sentence end.
            if word.endswith(".."):
                word = word.rstrip(".") + "."
            word_dict["word"] = word

    return wsm
# Characters that mark the end of a sentence when they are the last character
# of a word; read by the sentence-boundary helpers in this module.
sentence_ending_punctuations: str = ".?!"
def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words):
    """
    Finds the index of the first word in a sentence based on speaker continuity and punctuation.

    Walks left from ``word_idx`` and continues while all of these hold:
    the previous word has the same speaker, the previous word does not end
    a sentence, and fewer than ``max_words`` steps have been taken.

    Args:
        word_idx (int): The current word index.
        word_list (list): A list of words (strings).
        speaker_list (list): A list of speakers corresponding to each word.
        max_words (int): Maximum number of words to walk back.

    Returns:
        int: Index of the first word of the sentence, i.e. 0 or the index right
        after a sentence-ending word, or -1 if the walk stopped for another
        reason (speaker change or ``max_words`` limit).
    """

    def is_word_sentence_end(x):
        """
        Checks if the word at index x ends with sentence-ending punctuation.

        Negative indices and empty words are never sentence ends. The
        empty-word guard avoids the IndexError that ``word[-1]`` raises on "".
        """
        if x < 0:
            return False
        word = word_list[x]
        return bool(word) and word[-1] in sentence_ending_punctuations

    left_idx = word_idx
    while (
        left_idx > 0
        and word_idx - left_idx < max_words
        and speaker_list[left_idx - 1] == speaker_list[left_idx]
        and not is_word_sentence_end(left_idx - 1)
    ):
        left_idx -= 1

    return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1
def get_last_word_idx_of_sentence(word_idx, word_list, max_words):
    """
    Finds the index of the last word in a sentence based on punctuation.

    Walks right from ``word_idx`` until it reaches a word that ends a
    sentence, the last word of the list, or the ``max_words`` step limit.

    Args:
        word_idx (int): The current word index.
        word_list (list): A list of words (strings).
        max_words (int): Maximum number of words to walk forward.

    Returns:
        int: Index of the last word of the sentence, i.e. a sentence-ending
        word or the final word of the list, or -1 if the ``max_words`` limit
        was hit first.
    """

    def is_word_sentence_end(x):
        """
        Checks if the word at index x ends with sentence-ending punctuation.

        Negative indices and empty words are never sentence ends. The
        empty-word guard avoids the IndexError that ``word[-1]`` raises on "".
        """
        if x < 0:
            return False
        word = word_list[x]
        return bool(word) and word[-1] in sentence_ending_punctuations

    right_idx = word_idx
    while (
        right_idx < len(word_list) - 1
        and right_idx - word_idx < max_words
        and not is_word_sentence_end(right_idx)
    ):
        right_idx += 1

    return (
        right_idx
        if right_idx == len(word_list) - 1 or is_word_sentence_end(right_idx)
        else -1
    )
def get_realigned_ws_mapping_with_punctuation(
    word_speaker_mapping, max_words_in_sentence=50
):
    """
    Realigns the speaker mapping with punctuation to ensure consistent speaker labels for sentences.

    Look at every word where the speaker changes on the next word and the word
    does not end a sentence. For each such word, locate the enclosing sentence
    (bounded by sentence-ending punctuation and at most
    ``max_words_in_sentence`` words). If the most frequent speaker in that
    sentence owns at least half of its words (integer division), relabel the
    whole sentence to that speaker. When speakers tie, the one that occurs
    first in the sentence wins.

    Args:
        word_speaker_mapping (list): A list of dictionaries containing at least
            ``"word"`` and ``"speaker"`` keys. It is not modified.
        max_words_in_sentence (int): Maximum number of words considered in a
            sentence for alignment.

    Returns:
        list: Shallow copies of the input dictionaries with updated
        ``"speaker"`` values.
    """

    def is_word_sentence_end(x):
        """
        Checks if the word at index x ends with sentence-ending punctuation.

        Negative indices and empty words are never sentence ends. The
        empty-word guard avoids the IndexError that ``word[-1]`` raises on "".
        """
        if x < 0:
            return False
        word = word_speaker_mapping[x]["word"]
        return bool(word) and word[-1] in sentence_ending_punctuations

    wsp_len = len(word_speaker_mapping)
    words_list = [entry["word"] for entry in word_speaker_mapping]
    speaker_list = [entry["speaker"] for entry in word_speaker_mapping]

    k = 0
    while k < wsp_len:
        if (
            k < wsp_len - 1
            and speaker_list[k] != speaker_list[k + 1]
            and not is_word_sentence_end(k)
        ):
            left_idx = get_first_word_idx_of_sentence(
                k, words_list, speaker_list, max_words_in_sentence
            )
            # The forward budget is what remains after the words already
            # consumed to the left, so the sentence spans at most
            # max_words_in_sentence words.
            right_idx = (
                get_last_word_idx_of_sentence(
                    k, words_list, max_words_in_sentence - k + left_idx - 1
                )
                if left_idx > -1
                else -1
            )
            if min(left_idx, right_idx) == -1:
                k += 1
                continue

            spk_labels = speaker_list[left_idx : right_idx + 1]
            # dict.fromkeys keeps first-occurrence order and max() returns the
            # first maximal item, so ties resolve deterministically. Iterating
            # a set, as before, depends on hash order for string labels.
            mod_speaker = max(dict.fromkeys(spk_labels), key=spk_labels.count)
            if spk_labels.count(mod_speaker) < len(spk_labels) // 2:
                k += 1
                continue

            speaker_list[left_idx : right_idx + 1] = [mod_speaker] * (
                right_idx - left_idx + 1
            )
            # Resume after the relabelled sentence.
            k = right_idx

        k += 1

    realigned_list = []
    for entry, speaker in zip(word_speaker_mapping, speaker_list):
        new_entry = entry.copy()
        new_entry["speaker"] = speaker
        realigned_list.append(new_entry)
    return realigned_list