# src/multivoice/lib/stt_punctuation.py
import logging
import re
from collections import Counter

from deepmultilingualpunctuation import PunctuationModel

from multivoice.lib.stt_langs import punct_model_langs
def restore_punctuation(wsm, info):
    """
    Restores sentence-ending punctuation on a list of words based on the detected language.

    Args:
        wsm (list): A list of dictionaries, each holding at least a "word" entry.
            The dictionaries are updated in place.
        info (object): An object whose ``language`` attribute names the detected language.

    Returns:
        list: The same ``wsm`` list, with punctuation restored where applicable.
    """
    logging.debug("Restoring punctuation...")

    if info.language not in punct_model_langs:
        logging.warning(
            f"Punctuation restoration is not available for {info.language} language."
            " Using the original punctuation."
        )
        return wsm

    if not wsm:
        # Nothing to punctuate, so skip loading the model altogether.
        return wsm

    # Initialize PunctuationModel with a specific model for punctuation restoration
    punct_model = PunctuationModel(model="kredor/punctuate-all")

    # Extract the words and ask the model where punctuation should be placed
    words_list = [entry["word"] for entry in wsm]
    labeled_words = punct_model.predict(words_list, chunk_size=230)

    ending_puncts = ".?!"
    model_puncts = ".,;:!?"
    acronym_pattern = re.compile(r"\b(?:[a-zA-Z]\.){2,}")

    def is_acronym(x):
        """
        Checks if a given string is a dotted acronym such as "U.S.A.".

        Args:
            x (str): The word to check.

        Returns:
            bool: True if the whole word is two or more "letter + dot" pairs.
        """
        return acronym_pattern.fullmatch(x) is not None

    # Element 1 of each prediction is used as the predicted punctuation mark
    for word_dict, labeled_tuple in zip(wsm, labeled_words):
        word = word_dict["word"]
        label = labeled_tuple[1]
        # Only add a sentence-ending mark when the word has no punctuation yet,
        # or when its trailing dot belongs to an acronym rather than a sentence.
        if (
            word
            and label in ending_puncts
            and (word[-1] not in model_puncts or is_acronym(word))
        ):
            word += label
            # An acronym followed by a predicted "." ends in "..": collapse the
            # run to a single dot so the acronym keeps its final period and the
            # word still reads as a sentence end.
            if word.endswith(".."):
                word = word.rstrip(".") + "."
            word_dict["word"] = word

    return wsm
# Characters that mark the end of a sentence when they terminate a word.
sentence_ending_punctuations = ".?!"


def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words):
    """
    Finds the index of the first word in a sentence based on speaker continuity and punctuation.

    Starting at ``word_idx`` the search moves left while the previous word has
    the same speaker, does not end a sentence, and fewer than ``max_words``
    steps have been taken.

    Args:
        word_idx (int): The current word index.
        word_list (list): A list of words.
        speaker_list (list): A list of speakers corresponding to each word.
        max_words (int): Maximum number of words to walk back from ``word_idx``.

    Returns:
        int: Index of the first word in the sentence, or -1 if the search stopped
        somewhere that is neither index 0 nor right after a sentence-ending word.
    """
    sentence_enders = tuple(sentence_ending_punctuations)

    def is_word_sentence_end(x):
        """
        Checks if the word at index x ends with a sentence-ending punctuation.

        Args:
            x (int): The index of the word to check.

        Returns:
            bool: True if the word ends with a sentence-ending punctuation, False otherwise.
        """
        # endswith() is simply False for an empty word, whereas indexing [-1]
        # would raise IndexError.
        return x >= 0 and word_list[x].endswith(sentence_enders)

    left_idx = word_idx
    while (
        left_idx > 0
        and word_idx - left_idx < max_words
        and speaker_list[left_idx - 1] == speaker_list[left_idx]
        and not is_word_sentence_end(left_idx - 1)
    ):
        left_idx -= 1

    # A valid sentence start is either the very first word or the word that
    # directly follows a sentence-ending word.
    return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1
def get_last_word_idx_of_sentence(word_idx, word_list, max_words):
    """
    Finds the index of the last word in a sentence based on punctuation.

    Starting at ``word_idx`` the search moves right until it reaches a word
    that ends a sentence, the end of ``word_list``, or ``max_words`` steps.

    Args:
        word_idx (int): The current word index.
        word_list (list): A list of words.
        max_words (int): Maximum number of words to walk forward from ``word_idx``.

    Returns:
        int: Index of the last word in the sentence, or -1 if the search stopped
        on a word that is neither the final word nor a sentence-ending word.
    """
    sentence_enders = tuple(sentence_ending_punctuations)
    last_idx = len(word_list) - 1

    def is_word_sentence_end(x):
        """
        Checks if the word at index x ends with a sentence-ending punctuation.

        Args:
            x (int): The index of the word to check.

        Returns:
            bool: True if the word ends with a sentence-ending punctuation, False otherwise.
        """
        # The range check keeps an empty list or out-of-range index from raising,
        # and endswith() is False for an empty word where [-1] would raise.
        return 0 <= x <= last_idx and word_list[x].endswith(sentence_enders)

    right_idx = word_idx
    while (
        right_idx < last_idx
        and right_idx - word_idx < max_words
        and not is_word_sentence_end(right_idx)
    ):
        right_idx += 1

    # A valid sentence end is either the final word or a sentence-ending word.
    if right_idx == last_idx or is_word_sentence_end(right_idx):
        return right_idx
    return -1
def get_realigned_ws_mapping_with_punctuation(
    word_speaker_mapping, max_words_in_sentence=50
):
    """
    Realigns the speaker mapping with punctuation to ensure consistent speaker labels for sentences.

    Whenever the speaker changes on a word that does not end a sentence, the
    enclosing sentence is located and, if one speaker holds enough of its
    words, the whole sentence is relabelled with that speaker.

    Args:
        word_speaker_mapping (list): A list of dictionaries containing words ("word")
            and their corresponding speakers ("speaker"). It is not modified.
        max_words_in_sentence (int): Maximum number of words considered in a sentence for alignment.

    Returns:
        list: A new list of copied dictionaries with updated speaker information.
    """
    sentence_enders = tuple(sentence_ending_punctuations)
    words_list = [entry["word"] for entry in word_speaker_mapping]
    speaker_list = [entry["speaker"] for entry in word_speaker_mapping]
    total_words = len(words_list)

    def is_word_sentence_end(x):
        """
        Checks if the word at index x ends with a sentence-ending punctuation.

        Args:
            x (int): The index of the word to check.

        Returns:
            bool: True if the word ends with a sentence-ending punctuation, False otherwise.
        """
        # endswith() is False for an empty word, whereas indexing [-1] would raise.
        return x >= 0 and words_list[x].endswith(sentence_enders)

    k = 0
    while k < total_words:
        speaker_changes_mid_sentence = (
            k < total_words - 1
            and speaker_list[k] != speaker_list[k + 1]
            and not is_word_sentence_end(k)
        )
        if speaker_changes_mid_sentence:
            left_idx = get_first_word_idx_of_sentence(
                k, words_list, speaker_list, max_words_in_sentence
            )
            # The words already consumed to the left of k count against the
            # sentence-length budget for the search to the right.
            right_idx = (
                get_last_word_idx_of_sentence(
                    k, words_list, max_words_in_sentence - k + left_idx - 1
                )
                if left_idx > -1
                else -1
            )
            if min(left_idx, right_idx) == -1:
                # No sentence boundary found on one side; leave labels as they are.
                k += 1
                continue

            spk_labels = speaker_list[left_idx : right_idx + 1]
            # Most frequent speaker in the sentence; on a tie the speaker seen
            # first wins, so the result does not depend on set iteration order.
            mod_speaker, mod_count = Counter(spk_labels).most_common(1)[0]
            if mod_count < len(spk_labels) // 2:
                k += 1
                continue

            speaker_list[left_idx : right_idx + 1] = [mod_speaker] * (
                right_idx - left_idx + 1
            )
            # Resume scanning after the sentence that was just relabelled.
            k = right_idx
        k += 1

    realigned_list = []
    for entry, speaker in zip(word_speaker_mapping, speaker_list):
        updated_entry = entry.copy()
        updated_entry["speaker"] = speaker
        realigned_list.append(updated_entry)
    return realigned_list