diff --git a/TTS/bin/remove_silence_using_vad.py b/TTS/bin/remove_silence_using_vad.py
index a8a60bf8..7d88ae91 100755
--- a/TTS/bin/remove_silence_using_vad.py
+++ b/TTS/bin/remove_silence_using_vad.py
@@ -4,6 +4,7 @@ import os
 import pathlib
 
 from tqdm import tqdm
+
 from TTS.utils.vad import get_vad_model_and_utils, remove_silence
 
 
@@ -16,7 +17,13 @@ def adjust_path_and_remove_silence(audio_path):
     # create all directory structure
     pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
     # remove the silence and save the audio
-    output_path = remove_silence(model_and_utils, audio_path, output_path, trim_just_beginning_and_end=args.trim_just_beginning_and_end, use_cuda=args.use_cuda)
+    output_path = remove_silence(
+        model_and_utils,
+        audio_path,
+        output_path,
+        trim_just_beginning_and_end=args.trim_just_beginning_and_end,
+        use_cuda=args.use_cuda,
+    )
 
     return output_path
 
diff --git a/TTS/utils/vad.py b/TTS/utils/vad.py
index 7384934a..033b911a 100644
--- a/TTS/utils/vad.py
+++ b/TTS/utils/vad.py
@@ -1,6 +1,7 @@
 import torch
 import torchaudio
 
+
 def read_audio(path):
     wav, sr = torchaudio.load(path)
 
@@ -9,39 +10,42 @@ def read_audio(path):
 
     return wav.squeeze(0), sr
 
+
 def resample_wav(wav, sr, new_sr):
     wav = wav.unsqueeze(0)
     transform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=new_sr)
     wav = transform(wav)
     return wav.squeeze(0)
 
+
 def map_timestamps_to_new_sr(vad_sr, new_sr, timestamps, just_begging_end=False):
     factor = new_sr / vad_sr
     new_timestamps = []
     if just_begging_end and timestamps:
         # get just the start and end timestamps
-        new_dict = {'start': int(timestamps[0]['start']*factor), 'end': int(timestamps[-1]['end']*factor)}
+        new_dict = {"start": int(timestamps[0]["start"] * factor), "end": int(timestamps[-1]["end"] * factor)}
         new_timestamps.append(new_dict)
     else:
         for ts in timestamps:
             # map to the new SR
-            new_dict = {'start': int(ts['start']*factor), 'end': int(ts['end']*factor)}
+            new_dict = {"start": int(ts["start"] * factor), "end": int(ts["end"] * factor)}
             new_timestamps.append(new_dict)
 
     return new_timestamps
 
+
 def get_vad_model_and_utils(use_cuda=False):
-    model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
-                                  model='silero_vad',
-                                  force_reload=True,
-                                  onnx=False)
+    model, utils = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=True, onnx=False)
     if use_cuda:
         model = model.cuda()
 
     get_speech_timestamps, save_audio, _, _, collect_chunks = utils
     return model, get_speech_timestamps, save_audio, collect_chunks
 
-def remove_silence(model_and_utils, audio_path, out_path, vad_sample_rate=8000, trim_just_beginning_and_end=True, use_cuda=False):
+
+def remove_silence(
+    model_and_utils, audio_path, out_path, vad_sample_rate=8000, trim_just_beginning_and_end=True, use_cuda=False
+):
 
     # get the VAD model and utils functions
     model, get_speech_timestamps, save_audio, collect_chunks = model_and_utils
@@ -62,7 +66,9 @@ def remove_silence(model_and_utils, audio_path, out_path, vad_sample_rate=8000,
     speech_timestamps = get_speech_timestamps(wav_vad, model, sampling_rate=vad_sample_rate, window_size_samples=768)
 
     # map the current speech_timestamps to the sample rate of the ground truth audio
-    new_speech_timestamps = map_timestamps_to_new_sr(vad_sample_rate, gt_sample_rate, speech_timestamps, trim_just_beginning_and_end)
+    new_speech_timestamps = map_timestamps_to_new_sr(
+        vad_sample_rate, gt_sample_rate, speech_timestamps, trim_just_beginning_and_end
+    )
 
     # if have speech timestamps else save the wav
     if new_speech_timestamps: