mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #1431 from coqui-ai/silero-vad
Replace webrtcvad by silero-vad
Commit: 464dc658ff

TTS/bin/remove_silence_using_vad.py
@@ -1,51 +1,31 @@
 import argparse
 import glob
 import multiprocessing
 import os
 import pathlib
 
-from tqdm.contrib.concurrent import process_map
+from tqdm import tqdm
 
-from TTS.utils.vad import get_vad_speech_segments, read_wave, write_wave
+from TTS.utils.vad import get_vad_model_and_utils, remove_silence
 
 
-def remove_silence(filepath):
-    output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
+def adjust_path_and_remove_silence(audio_path):
+    output_path = audio_path.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
     # ignore if the file exists
     if os.path.exists(output_path) and not args.force:
-        return
+        return output_path
+
     # create all directory structure
     pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
-    # load wave
-    audio, sample_rate = read_wave(filepath)
+    # remove the silence and save the audio
+    output_path = remove_silence(
+        model_and_utils,
+        audio_path,
+        output_path,
+        trim_just_beginning_and_end=args.trim_just_beginning_and_end,
+        use_cuda=args.use_cuda,
+    )
 
-    # get speech segments
-    segments = get_vad_speech_segments(audio, sample_rate, aggressiveness=args.aggressiveness)
-
-    segments = list(segments)
-    num_segments = len(segments)
-    flag = False
-    # create the output wave
-    if num_segments != 0:
-        for i, segment in reversed(list(enumerate(segments))):
-            if i >= 1:
-                if not flag:
-                    concat_segment = segment
-                    flag = True
-                else:
-                    concat_segment = segment + concat_segment
-            else:
-                if flag:
-                    segment = segment + concat_segment
-                # print("Saving: ", output_path)
-                write_wave(output_path, segment, sample_rate)
-                return
-    else:
-        print("> Just Copying the file to:", output_path)
-        # if fail to remove silence just write the file
-        write_wave(output_path, audio, sample_rate)
-        return
+    return output_path
 
 
 def preprocess_audios():
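For reference, the path handling in adjust_path_and_remove_silence above simply swaps the input root for the output root while keeping the rest of the tree intact. A minimal sketch of the same replace() call, with made-up VCTK-style paths (illustrative only, not taken from the PR):

    import os

    # hypothetical values standing in for args.input_dir / args.output_dir
    input_dir = "VCTK-Corpus/"
    output_dir = "VCTK-Corpus-removed-silence/"

    audio_path = "VCTK-Corpus/wav48_silence_trimmed/p225/p225_001_mic1.flac"
    # same rewrite performed by adjust_path_and_remove_silence
    output_path = audio_path.replace(os.path.join(input_dir, ""), os.path.join(output_dir, ""))
    print(output_path)  # VCTK-Corpus-removed-silence/wav48_silence_trimmed/p225/p225_001_mic1.flac
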
@@ -54,17 +34,24 @@ def preprocess_audios():
     if not args.force:
         print("> Ignoring files that already exist in the output directory.")
 
+    if args.trim_just_beginning_and_end:
+        print("> Trimming just the beginning and the end with nonspeech parts.")
+    else:
+        print("> Trimming all nonspeech parts.")
+
     if files:
         # create threads
-        num_threads = multiprocessing.cpu_count()
-        process_map(remove_silence, files, max_workers=num_threads, chunksize=15)
+        # num_threads = multiprocessing.cpu_count()
+        # process_map(adjust_path_and_remove_silence, files, max_workers=num_threads, chunksize=15)
+        for f in tqdm(files):
+            adjust_path_and_remove_silence(f)
     else:
         print("> No files Found !")
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="python remove_silence.py -i=VCTK-Corpus-bk/ -o=../VCTK-Corpus-removed-silence -g=wav48/*/*.wav -a=2"
+        description="python TTS/bin/remove_silence_using_vad.py -i=VCTK-Corpus/ -o=VCTK-Corpus-removed-silence/ -g=wav48_silence_trimmed/*/*_mic1.flac --trim_just_beginning_and_end True"
     )
     parser.add_argument("-i", "--input_dir", type=str, default="../VCTK-Corpus", help="Dataset root dir")
     parser.add_argument(
@@ -79,11 +66,20 @@ if __name__ == "__main__":
         help="path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav",
     )
     parser.add_argument(
-        "-a",
-        "--aggressiveness",
-        type=int,
-        default=2,
-        help="set its aggressiveness mode, which is an integer between 0 and 3. 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive.",
+        "-t",
+        "--trim_just_beginning_and_end",
+        type=bool,
+        default=True,
+        help="If True this script will trim just the beginning and end nonspeech parts. If False all nonspeech parts will be trim. Default True",
+    )
+    parser.add_argument(
+        "-c",
+        "--use_cuda",
+        type=bool,
+        default=False,
+        help="If True use cuda",
     )
     args = parser.parse_args()
+    # load the model and utils
+    model_and_utils = get_vad_model_and_utils(use_cuda=args.use_cuda)
     preprocess_audios()
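
The new description string above doubles as a usage example for the reworked script. One caveat worth knowing about the two boolean flags: argparse's type=bool converts any non-empty string to True, so passing "False" on the command line does not disable them. A quick hedged check of that behaviour:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--use_cuda", type=bool, default=False)
    print(parser.parse_args(["--use_cuda", "False"]).use_cuda)  # True, because bool("False") is truthy
    print(parser.parse_args([]).use_cuda)                       # False, the default applies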

TTS/utils/vad.py
@@ -1,144 +1,81 @@
-# This code is adpated from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py
-import collections
-import contextlib
-import wave
-
-import webrtcvad
+import torch
+import torchaudio
 
 
-def read_wave(path):
-    """Reads a .wav file.
-
-    Takes the path, and returns (PCM audio data, sample rate).
-    """
-    with contextlib.closing(wave.open(path, "rb")) as wf:
-        num_channels = wf.getnchannels()
-        assert num_channels == 1
-        sample_width = wf.getsampwidth()
-        assert sample_width == 2
-        sample_rate = wf.getframerate()
-        assert sample_rate in (8000, 16000, 32000, 48000)
-        pcm_data = wf.readframes(wf.getnframes())
-        return pcm_data, sample_rate
+def read_audio(path):
+    wav, sr = torchaudio.load(path)
+
+    if wav.size(0) > 1:
+        wav = wav.mean(dim=0, keepdim=True)
+
+    return wav.squeeze(0), sr
 
 
-def write_wave(path, audio, sample_rate):
-    """Writes a .wav file.
-
-    Takes path, PCM audio data, and sample rate.
-    """
-    with contextlib.closing(wave.open(path, "wb")) as wf:
-        wf.setnchannels(1)
-        wf.setsampwidth(2)
-        wf.setframerate(sample_rate)
-        wf.writeframes(audio)
+def resample_wav(wav, sr, new_sr):
+    wav = wav.unsqueeze(0)
+    transform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=new_sr)
+    wav = transform(wav)
+    return wav.squeeze(0)
 
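A short, hedged illustration of the two torchaudio helpers added above (the file name is hypothetical; any audio file readable by torchaudio would do):

    import torchaudio

    wav, sr = torchaudio.load("sample.flac")   # (channels, samples)
    if wav.size(0) > 1:                        # downmix to mono, as read_audio does
        wav = wav.mean(dim=0, keepdim=True)
    wav = wav.squeeze(0)                       # 1-D tensor, as read_audio returns

    # resample to the VAD rate (8 kHz by default in remove_silence), as resample_wav does
    resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=8000)
    wav_8k = resampler(wav.unsqueeze(0)).squeeze(0)
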
-class Frame(object):
-    """Represents a "frame" of audio data."""
-
-    def __init__(self, _bytes, timestamp, duration):
-        self.bytes = _bytes
-        self.timestamp = timestamp
-        self.duration = duration
+def map_timestamps_to_new_sr(vad_sr, new_sr, timestamps, just_begging_end=False):
+    factor = new_sr / vad_sr
+    new_timestamps = []
+    if just_begging_end and timestamps:
+        # get just the start and end timestamps
+        new_dict = {"start": int(timestamps[0]["start"] * factor), "end": int(timestamps[-1]["end"] * factor)}
+        new_timestamps.append(new_dict)
+    else:
+        for ts in timestamps:
+            # map to the new SR
+            new_dict = {"start": int(ts["start"] * factor), "end": int(ts["end"] * factor)}
+            new_timestamps.append(new_dict)
+
+    return new_timestamps
 
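To make the rescaling in map_timestamps_to_new_sr concrete: timestamps are produced in samples at the VAD rate and multiplied by new_sr / vad_sr to index the original audio. A tiny worked example with made-up numbers:

    # hypothetical rates: VAD ran at 8 kHz, the source file is 48 kHz
    vad_sr, new_sr = 8000, 48000
    factor = new_sr / vad_sr  # 6.0
    timestamps = [{"start": 4000, "end": 20000}]  # 0.5 s .. 2.5 s in VAD samples
    mapped = [{"start": int(t["start"] * factor), "end": int(t["end"] * factor)} for t in timestamps]
    print(mapped)  # [{'start': 24000, 'end': 120000}] -> still 0.5 s .. 2.5 s at 48 kHz
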
-def frame_generator(frame_duration_ms, audio, sample_rate):
-    """Generates audio frames from PCM audio data.
-
-    Takes the desired frame duration in milliseconds, the PCM data, and
-    the sample rate.
-
-    Yields Frames of the requested duration.
-    """
-    n = int(sample_rate * (frame_duration_ms / 1000.0) * 2)
-    offset = 0
-    timestamp = 0.0
-    duration = (float(n) / sample_rate) / 2.0
-    while offset + n < len(audio):
-        yield Frame(audio[offset : offset + n], timestamp, duration)
-        timestamp += duration
-        offset += n
+def get_vad_model_and_utils(use_cuda=False):
+    model, utils = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=True, onnx=False)
+    if use_cuda:
+        model = model.cuda()
+
+    get_speech_timestamps, save_audio, _, _, collect_chunks = utils
+    return model, get_speech_timestamps, save_audio, collect_chunks
 
-def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
-    """Filters out non-voiced audio frames.
-
-    Given a webrtcvad.Vad and a source of audio frames, yields only
-    the voiced audio.
-
-    Uses a padded, sliding window algorithm over the audio frames.
-    When more than 90% of the frames in the window are voiced (as
-    reported by the VAD), the collector triggers and begins yielding
-    audio frames. Then the collector waits until 90% of the frames in
-    the window are unvoiced to detrigger.
-
-    The window is padded at the front and back to provide a small
-    amount of silence or the beginnings/endings of speech around the
-    voiced frames.
-
-    Arguments:
-
-    sample_rate - The audio sample rate, in Hz.
-    frame_duration_ms - The frame duration in milliseconds.
-    padding_duration_ms - The amount to pad the window, in milliseconds.
-    vad - An instance of webrtcvad.Vad.
-    frames - a source of audio frames (sequence or generator).
-
-    Returns: A generator that yields PCM audio data.
-    """
-    num_padding_frames = int(padding_duration_ms / frame_duration_ms)
-    # We use a deque for our sliding window/ring buffer.
-    ring_buffer = collections.deque(maxlen=num_padding_frames)
-    # We have two states: TRIGGERED and NOTTRIGGERED. We start in the
-    # NOTTRIGGERED state.
-    triggered = False
-
-    voiced_frames = []
-    for frame in frames:
-        is_speech = vad.is_speech(frame.bytes, sample_rate)
-
-        # sys.stdout.write('1' if is_speech else '0')
-        if not triggered:
-            ring_buffer.append((frame, is_speech))
-            num_voiced = len([f for f, speech in ring_buffer if speech])
-            # If we're NOTTRIGGERED and more than 90% of the frames in
-            # the ring buffer are voiced frames, then enter the
-            # TRIGGERED state.
-            if num_voiced > 0.9 * ring_buffer.maxlen:
-                triggered = True
-                # sys.stdout.write('+(%s)' % (ring_buffer[0][0].timestamp,))
-                # We want to yield all the audio we see from now until
-                # we are NOTTRIGGERED, but we have to start with the
-                # audio that's already in the ring buffer.
-                for f, _ in ring_buffer:
-                    voiced_frames.append(f)
-                ring_buffer.clear()
-        else:
-            # We're in the TRIGGERED state, so collect the audio data
-            # and add it to the ring buffer.
-            voiced_frames.append(frame)
-            ring_buffer.append((frame, is_speech))
-            num_unvoiced = len([f for f, speech in ring_buffer if not speech])
-            # If more than 90% of the frames in the ring buffer are
-            # unvoiced, then enter NOTTRIGGERED and yield whatever
-            # audio we've collected.
-            if num_unvoiced > 0.9 * ring_buffer.maxlen:
-                # sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
-                triggered = False
-                yield b"".join([f.bytes for f in voiced_frames])
-                ring_buffer.clear()
-                voiced_frames = []
-    # If we have any leftover voiced audio when we run out of input,
-    # yield it.
-    if voiced_frames:
-        yield b"".join([f.bytes for f in voiced_frames])
-
-
-def get_vad_speech_segments(audio, sample_rate, aggressiveness=2, padding_duration_ms=300):
-
-    vad = webrtcvad.Vad(int(aggressiveness))
-    frames = list(frame_generator(30, audio, sample_rate))
-    segments = vad_collector(sample_rate, 30, padding_duration_ms, vad, frames)
-
-    return segments
+def remove_silence(
+    model_and_utils, audio_path, out_path, vad_sample_rate=8000, trim_just_beginning_and_end=True, use_cuda=False
+):
+
+    # get the VAD model and utils functions
+    model, get_speech_timestamps, save_audio, collect_chunks = model_and_utils
+
+    # read ground truth wav and resample the audio for the VAD
+    wav, gt_sample_rate = read_audio(audio_path)
+
+    # if needed, resample the audio for the VAD model
+    if gt_sample_rate != vad_sample_rate:
+        wav_vad = resample_wav(wav, gt_sample_rate, vad_sample_rate)
+    else:
+        wav_vad = wav
+
+    if use_cuda:
+        wav_vad = wav_vad.cuda()
+
+    # get speech timestamps from full audio file
+    speech_timestamps = get_speech_timestamps(wav_vad, model, sampling_rate=vad_sample_rate, window_size_samples=768)
+
+    # map the current speech_timestamps to the sample rate of the ground truth audio
+    new_speech_timestamps = map_timestamps_to_new_sr(
+        vad_sample_rate, gt_sample_rate, speech_timestamps, trim_just_beginning_and_end
+    )
+
+    # if have speech timestamps else save the wav
+    if new_speech_timestamps:
+        wav = collect_chunks(new_speech_timestamps, wav)
+    else:
+        print(f"> The file {audio_path} probably does not have speech please check it !!")
+
+    # save audio
+    save_audio(out_path, wav, sampling_rate=gt_sample_rate)
+    return out_path
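
Taken together, the new module is driven in two calls, which is exactly what the updated script does. A minimal, hedged sketch of using it directly (file paths are illustrative; the first call downloads the Silero model via torch.hub and needs network access, and the output directory must already exist since save_audio does not create it):

    from TTS.utils.vad import get_vad_model_and_utils, remove_silence

    # load the Silero VAD model and its helper functions once, then reuse them
    model_and_utils = get_vad_model_and_utils(use_cuda=False)

    # trim leading/trailing non-speech from a single file
    out = remove_silence(
        model_and_utils,
        "VCTK-Corpus/wav48_silence_trimmed/p225/p225_001_mic1.flac",
        "VCTK-Corpus-removed-silence/wav48_silence_trimmed/p225/p225_001_mic1.flac",
        trim_just_beginning_and_end=True,
        use_cuda=False,
    )
    print(out)  # path of the trimmed file
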
requirements.txt
@@ -34,5 +34,3 @@ mecab-python3==1.0.3
 unidic-lite==1.0.8
 # gruut+supported langs
 gruut[cs,de,es,fr,it,nl,pt,ru,sv]==2.2.3
-# others
-webrtcvad # for VAD