# src/multivoice/lib/stt_process_audio.py
import faster_whisper
import logging
import os
import torch
import torchaudio
import multivoice.identify_voice
from multivoice.lib.stt_align_transcription import align_transcription
from multivoice.lib.stt_diarize_audio import diarize_audio
from multivoice.lib.stt_langs import process_language_arg
from multivoice.lib.stt_punctuation import get_realigned_ws_mapping_with_punctuation
from multivoice.lib.stt_punctuation import restore_punctuation
from multivoice.lib.stt_separate_vocals import separate_vocals
from multivoice.lib.stt_transcribe_audio import transcribe_audio
from multivoice.lib.stt_speaker_mapping import (
    get_sentences_speaker_mapping,
    get_words_speaker_mapping,
)

def process_audio(args, temp_path):
"""
Process the input audio to generate speaker-segment mapping with timestamps and transcriptions.
Args:
args (Namespace): Command line arguments parsed by argparse.
temp_path (str): Temporary directory path for intermediate files.
Returns:
list: A list of dictionaries, each representing a sentence segment with its start time,
end time, transcription, and speaker ID.
"""
    # Compute type for faster-whisper, chosen based on the device
    mtypes = {"cpu": "int8", "cuda": "float16"}
    # Determine the language for transcription based on input arguments
    language = process_language_arg(args.language, args.model_name)
    # Optionally separate vocals from the audio if stemming is enabled
    if args.stemming:
        vocal_target = separate_vocals(args.audio, temp_path, args.device)
    else:
        vocal_target = args.audio  # Use the original audio file if stemming is disabled
    # Transcribe the vocal target to get a full transcript and additional information
    full_transcript, info = transcribe_audio(
        vocal_target,
        language,
        args.suppress_numerals,
        args.batch_size,
        args.device,
        mtypes,
        args,
    )
    # Decode audio waveform from the vocal target file
    audio_waveform = faster_whisper.decode_audio(vocal_target)
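    # Note: faster_whisper.decode_audio returns a mono float32 numpy array
    # resampled to 16 kHz by default, which is why 16000 is used as the sample
    # rate throughout the rest of this function.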
    # Perform forced alignment to get word-level timestamps
    word_timestamps = align_transcription(
        full_transcript, audio_waveform, args.device, args.batch_size, info
    )
    # Write the waveform out as a 16 kHz mono WAV file for the NeMo diarization models
    torchaudio.save(
        os.path.join(temp_path, "mono_file.wav"),
        torch.from_numpy(audio_waveform).unsqueeze(0).float(),
        16000,
        channels_first=True,
    )
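    # The diarization step below presumably picks up this mono_file.wav from
    # temp_path, since only the directory and device are passed to it.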
    # Diarize the audio to get speaker timestamps
    speaker_ts = diarize_audio(temp_path, args.device)
    # Map words to speakers based on start times of words and speaker timestamps
    wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, "start")
    # Restore punctuation in the word-segment mapping using additional information from transcription
    wsm = restore_punctuation(wsm, info)
    # Realign word-segment mapping with punctuation for better accuracy
    wsm = get_realigned_ws_mapping_with_punctuation(wsm)
    # Map sentences to speakers based on aligned words and speaker timestamps
    ssm = get_sentences_speaker_mapping(wsm, speaker_ts)
    # Create speaker directories
    speaker_dirs = {}
    for sentence in ssm:
        speaker_id = sentence["speaker"]
        if speaker_id not in speaker_dirs:
            # Only create a directory if the speaker appears in the diarization timestamps
            speaker_dir_name = next(
                (item[2] for item in speaker_ts if item[2] == speaker_id), None
            )
            if speaker_dir_name is not None:
                speaker_dir = os.path.join(temp_path, str(speaker_dir_name))
                os.makedirs(speaker_dir, exist_ok=True)
                speaker_dirs[speaker_id] = speaker_dir
    # Save each segment and organize by speaker
    for idx, sentence in enumerate(ssm):
        start_time, end_time = sentence["start_time"], sentence["end_time"]
        start_idx = int(start_time * 16000 / 1000)
        end_idx = int(end_time * 16000 / 1000)
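        # Timestamps are in milliseconds, so a sentence spanning 2500 ms to 4000 ms
        # maps to samples 2500 * 16000 / 1000 = 40000 through 4000 * 16000 / 1000 = 64000.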
        segment_waveform = (
            torch.from_numpy(audio_waveform[start_idx:end_idx]).unsqueeze(0).float()
        )
        speaker_id = sentence["speaker"]
        if speaker_id not in speaker_dirs:
            # Fall back to a generic Speaker_<id> directory if none was created above
            speaker_dir = os.path.join(temp_path, f"Speaker_{speaker_id}")
            os.makedirs(speaker_dir, exist_ok=True)
            speaker_dirs[speaker_id] = speaker_dir
        # Save the segment in the respective speaker directory
        output_segment_path = os.path.join(
            speaker_dirs[speaker_id], f"segment_{idx + 1}.wav"
        )
        torchaudio.save(output_segment_path, segment_waveform, 16000)
    # Combine WAV files per speaker
    combined_wav_paths = {}
    for speaker_id, speaker_dir in speaker_dirs.items():
        combined_wav_path = os.path.join(temp_path, f"combined_{speaker_id}.wav")
        combine_wav_files(speaker_dir, combined_wav_path)
        combined_wav_paths[speaker_id] = combined_wav_path
    # Identify voice for each combined WAV file
    speaker_ids_mapping = {}
    for speaker_id, combined_wav_path in combined_wav_paths.items():
        identifier = multivoice.identify_voice.SpeakerVerifierCLI()
        voice_id = identifier.identify_voice(combined_wav_path)
        logging.debug(
            "Voice ID: %s, Combined Wav Path: %s", voice_id, combined_wav_path
        )
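        # identify_voice is expected to return the name of a matching enrolled
        # speaker, or None when no known voice is recognised.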
        # Update speaker in sentence mapping if a valid voice ID is obtained
        if voice_id is not None:
            speaker_ids_mapping[speaker_id] = voice_id
    # Log the final speaker-segment mapping
    logging.debug("SSM: %s", ssm)
    # Replace speaker IDs with identified voice names in the SSM
    for sentence in ssm:
        spk_id = sentence["speaker"]
        if spk_id in speaker_ids_mapping:
            sentence["speaker"] = speaker_ids_mapping[spk_id]
    return ssm

def combine_wav_files(input_dir, output_path):
"""
Combine all WAV files in a directory into a single WAV file using sox.
Args:
input_dir (str): Directory containing WAV files to be combined.
output_path (str): Output path for the combined WAV file.
"""
logging.debug("Combining WAV files in %s into %s", input_dir, output_path)
wav_files = [
os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith(".wav")
]
if not wav_files:
logging.warning(f"No WAV files found in {input_dir}")
return
# Use sox to combine the WAV files
quoted_wav_files = [f'"{file}"' for file in wav_files]
command = f'sox {" ".join(quoted_wav_files)} "{output_path}"'
return_code = os.system(command)
if return_code != 0:
logging.error("Failed to combine WAV files using sox. Command: %s", command)