# src/multivoice/lib/stt_diarize_audio.py
import logging
import os
from nemo.collections.asr.models.msdd_models import NeuralDiarizer
from multivoice.lib.stt_config import create_config
def diarize_audio(temp_path, device):
    """
    Diarizes the audio to identify speaker turns.

    Args:
        temp_path (str): The path where temporary files will be stored.
        device (str): The device to use for processing ('cpu' or 'cuda').

    Returns:
        list: A list of [start_ms, end_ms, speaker_id] entries, one per
        identified speaker segment, with times in milliseconds.
    """
    logging.debug("Diarizing the audio...")
    logging.debug("temp_path: %s", temp_path)

    # Initialize the NeMo MSDD diarization model from a config built for the temporary path
    msdd_model = NeuralDiarizer(cfg=create_config(temp_path)).to(device)
    msdd_model.diarize()  # Perform diarization; predictions are written under pred_rttms/

    speaker_ts = []
    # Parse the output RTTM file to extract speaker segments.
    # Splitting on single spaces leaves empty strings for padded columns, so the
    # start time lands at index 5, the duration at index 8, and the speaker label at index 11.
    with open(os.path.join(temp_path, "pred_rttms", "mono_file.rttm"), "r") as f:
        for line in f:
            line_list = line.split(" ")
            s = int(float(line_list[5]) * 1000)  # Start time in milliseconds
            e = s + int(float(line_list[8]) * 1000)  # End = start + duration, in milliseconds
            speaker_ts.append(
                [s, e, int(line_list[11].split("_")[-1])]  # [start_time, end_time, speaker_id]
            )
    return speaker_ts
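
# Illustrative usage sketch, not part of the original module. It assumes the
# temporary directory has already been prepared with the audio that
# create_config() points the diarizer at; the work directory path below is a
# placeholder. torch is imported only to pick a device.
if __name__ == "__main__":
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
    segments = diarize_audio("/tmp/diarize_work", device)  # placeholder work directory
    for start_ms, end_ms, speaker_id in segments:
        print(f"{start_ms:>8} - {end_ms:>8} ms  speaker_{speaker_id}")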