From b1e3160884ec8c2bf25832c13de2266544d9ae10 Mon Sep 17 00:00:00 2001 From: Branislav Gerazov Date: Thu, 4 Feb 2021 09:52:03 +0100 Subject: [PATCH] waveRNN fix --- TTS/bin/train_vocoder_wavernn.py | 25 +++++++--- TTS/vocoder/configs/wavernn_config.json | 3 ++ TTS/vocoder/models/wavernn.py | 9 ++-- TTS/vocoder/utils/generic_utils.py | 65 +++++++++++++------------ 4 files changed, 59 insertions(+), 43 deletions(-) diff --git a/TTS/bin/train_vocoder_wavernn.py b/TTS/bin/train_vocoder_wavernn.py index 6847e011..d38bdee5 100644 --- a/TTS/bin/train_vocoder_wavernn.py +++ b/TTS/bin/train_vocoder_wavernn.py @@ -32,7 +32,7 @@ from TTS.vocoder.datasets.preprocess import ( load_wav_feat_data ) from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss -from TTS.vocoder.utils.generic_utils import setup_wavernn +from TTS.vocoder.utils.generic_utils import setup_generator from TTS.vocoder.utils.io import save_best_model, save_checkpoint @@ -200,9 +200,14 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0] wav = ap.load_wav(wav_path) ground_mel = ap.melspectrogram(wav) - sample_wav = model.inference(ground_mel, c.batched, - c.target_samples, c.overlap_samples, - use_cuda) + ground_mel = torch.FloatTensor(ground_mel) + if use_cuda: + ground_mel = ground_mel.cuda(non_blocking=True) + sample_wav = model.inference(ground_mel, + c.batched, + c.target_samples, + c.overlap_samples, + ) predict_mel = ap.melspectrogram(sample_wav) # compute spectrograms @@ -284,8 +289,14 @@ def evaluate(model, criterion, ap, global_step, epoch): eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0] wav = ap.load_wav(wav_path) ground_mel = ap.melspectrogram(wav) - sample_wav = model.inference(ground_mel, c.batched, c.target_samples, - c.overlap_samples, use_cuda) + ground_mel = torch.FloatTensor(ground_mel) + if use_cuda: + ground_mel = ground_mel.cuda(non_blocking=True) + sample_wav = model.inference(ground_mel, + c.batched, + c.target_samples, + c.overlap_samples, + ) predict_mel = ap.melspectrogram(sample_wav) # Sample audio @@ -343,7 +354,7 @@ def main(args): # pylint: disable=redefined-outer-name eval_data, train_data = load_wav_data( c.data_path, c.eval_split_size) # setup model - model_wavernn = setup_wavernn(c) + model_wavernn = setup_generator(c) # setup amp scaler scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None diff --git a/TTS/vocoder/configs/wavernn_config.json b/TTS/vocoder/configs/wavernn_config.json index 58667b69..effb103b 100644 --- a/TTS/vocoder/configs/wavernn_config.json +++ b/TTS/vocoder/configs/wavernn_config.json @@ -56,6 +56,9 @@ "upsample_factors": [4, 8, 8] // this needs to correctly factorise hop_length }, +// GENERATOR - for backward compatibility + "generator_model": "WaveRNN", + // DATASET //"use_gta": true, // use computed gta features from the tts model "data_path": "/home/erogol/Data/libritts/LibriTTS/train-clean-360/", // path containing training wav files diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index cb03deb3..fdb71cff 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -260,7 +260,7 @@ class WaveRNN(nn.Module): x = F.relu(self.fc2(x)) return self.fc3(x) - def inference(self, mels, batched, target, overlap): + def inference(self, mels, batched=None, target=None, overlap=None): self.eval() device = mels.device @@ -350,10 +350,11 @@ class WaveRNN(nn.Module): self.gen_display(i, seq_len, b_size, start) output = torch.stack(output).transpose(0, 1) - output = output.cpu().numpy() - output = output.astype(np.float64) - + output = output.cpu() if batched: + output = output.numpy() + output = output.astype(np.float64) + output = self.xfade_and_unfold(output, target, overlap) else: output = output[0] diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index fb943a37..b43a1263 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -61,40 +61,37 @@ def plot_results(y_hat, y, ap, global_step, name_prefix): return figures -def to_camel(text): +def to_camel(text, cap=True): text = text.capitalize() return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text) -def setup_wavernn(c): - print(" > Model: WaveRNN") - MyModel = importlib.import_module("TTS.vocoder.models.wavernn") - MyModel = getattr(MyModel, "WaveRNN") - model = MyModel( - rnn_dims=c.wavernn_model_params['rnn_dims'], - fc_dims=c.wavernn_model_params['fc_dims'], - mode=c.mode, - mulaw=c.mulaw, - pad=c.padding, - use_aux_net=c.wavernn_model_params['use_aux_net'], - use_upsample_net=c.wavernn_model_params['use_upsample_net'], - upsample_factors=c.wavernn_model_params['upsample_factors'], - feat_dims=c.audio['num_mels'], - compute_dims=c.wavernn_model_params['compute_dims'], - res_out_dims=c.wavernn_model_params['res_out_dims'], - num_res_blocks=c.wavernn_model_params['num_res_blocks'], - hop_length=c.audio["hop_length"], - sample_rate=c.audio["sample_rate"], - ) - return model - - def setup_generator(c): print(" > Generator Model: {}".format(c.generator_model)) MyModel = importlib.import_module('TTS.vocoder.models.' + c.generator_model.lower()) - MyModel = getattr(MyModel, to_camel(c.generator_model)) - if c.generator_model.lower() in 'melgan_generator': + # this is to preserve the WaveRNN class name (instead of Wavernn) + if c.generator_model != 'WaveRNN': + MyModel = getattr(MyModel, to_camel(c.generator_model)) + else: + MyModel = getattr(MyModel, c.generator_model) + if c.generator_model.lower() in 'wavernn': + model = MyModel( + rnn_dims=c.wavernn_model_params['rnn_dims'], + fc_dims=c.wavernn_model_params['fc_dims'], + mode=c.mode, + mulaw=c.mulaw, + pad=c.padding, + use_aux_net=c.wavernn_model_params['use_aux_net'], + use_upsample_net=c.wavernn_model_params['use_upsample_net'], + upsample_factors=c.wavernn_model_params['upsample_factors'], + feat_dims=c.audio['num_mels'], + compute_dims=c.wavernn_model_params['compute_dims'], + res_out_dims=c.wavernn_model_params['res_out_dims'], + num_res_blocks=c.wavernn_model_params['num_res_blocks'], + hop_length=c.audio["hop_length"], + sample_rate=c.audio["sample_rate"],) + elif c.generator_model.lower() in 'melgan_generator': model = MyModel( in_channels=c.audio['num_mels'], out_channels=1, @@ -103,9 +100,10 @@ def setup_generator(c): upsample_factors=c.generator_model_params['upsample_factors'], res_kernel=3, num_res_blocks=c.generator_model_params['num_res_blocks']) - if c.generator_model in 'melgan_fb_generator': - pass - if c.generator_model.lower() in 'multiband_melgan_generator': + elif c.generator_model in 'melgan_fb_generator': + raise ValueError( + 'melgan_fb_generator is now fullband_melgan_generator') + elif c.generator_model.lower() in 'multiband_melgan_generator': model = MyModel( in_channels=c.audio['num_mels'], out_channels=4, @@ -114,7 +112,7 @@ def setup_generator(c): upsample_factors=c.generator_model_params['upsample_factors'], res_kernel=3, num_res_blocks=c.generator_model_params['num_res_blocks']) - if c.generator_model.lower() in 'fullband_melgan_generator': + elif c.generator_model.lower() in 'fullband_melgan_generator': model = MyModel( in_channels=c.audio['num_mels'], out_channels=1, @@ -123,7 +121,7 @@ def setup_generator(c): upsample_factors=c.generator_model_params['upsample_factors'], res_kernel=3, num_res_blocks=c.generator_model_params['num_res_blocks']) - if c.generator_model.lower() in 'parallel_wavegan_generator': + elif c.generator_model.lower() in 'parallel_wavegan_generator': model = MyModel( in_channels=1, out_channels=1, @@ -138,7 +136,7 @@ def setup_generator(c): bias=True, use_weight_norm=True, upsample_factors=c.generator_model_params['upsample_factors']) - if c.generator_model.lower() in 'wavegrad': + elif c.generator_model.lower() in 'wavegrad': model = MyModel( in_channels=c['audio']['num_mels'], out_channels=1, @@ -149,6 +147,9 @@ def setup_generator(c): ublock_out_channels=c['model_params']['ublock_out_channels'], upsample_factors=c['model_params']['upsample_factors'], upsample_dilations=c['model_params']['upsample_dilations']) + else: + raise NotImplementedError( + f'Model {c.generator_model} not implemented!') return model