diff --git a/TTS/bin/train_wavernn_vocoder.py b/TTS/bin/train_wavernn_vocoder.py
index 533fe0ce..78984510 100644
--- a/TTS/bin/train_wavernn_vocoder.py
+++ b/TTS/bin/train_wavernn_vocoder.py
@@ -1,8 +1,5 @@
 import argparse
-import math
 import os
-import pickle
-import shutil
 import sys
 import traceback
 import time
@@ -11,7 +8,8 @@ import random
 
 import torch
 from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
+
+# from torch.utils.data.distributed import DistributedSampler
 
 from TTS.tts.utils.visual import plot_spectrogram
 from TTS.utils.audio import AudioProcessor
@@ -30,7 +28,6 @@ from TTS.utils.generic_utils import (
 )
 from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
 from TTS.vocoder.datasets.preprocess import (
-    load_wav_data,
     find_feat_files,
     load_wav_feat_data,
     preprocess_wav_files,
@@ -322,7 +319,7 @@ def main(args):  # pylint: disable=redefined-outer-name
             CONFIG.data_path, mel_feat_path, CONFIG.eval_split_size
         )
     else:
-        print(f" > No feature data found. Preprocessing...")
+        print(" > No feature data found. Preprocessing...")
         # preprocessing feature data from given wav files
         preprocess_wav_files(OUT_PATH, CONFIG, ap)
         eval_data, train_data = load_wav_feat_data(
diff --git a/TTS/vocoder/datasets/wavernn_dataset.py b/TTS/vocoder/datasets/wavernn_dataset.py
index 1b0a8077..5d5b9f15 100644
--- a/TTS/vocoder/datasets/wavernn_dataset.py
+++ b/TTS/vocoder/datasets/wavernn_dataset.py
@@ -1,5 +1,3 @@
-import os
-import glob
 import torch
 import numpy as np
 from torch.utils.data import Dataset
@@ -42,7 +40,7 @@ class WaveRNNDataset(Dataset):
         wavpath, feat_path = self.item_list[index]
         m = np.load(feat_path.replace("/quant/", "/mel/"))
         # x = self.wav_cache[index]
-        if 5 > m.shape[-1]:
+        if m.shape[-1] < 5:
            print(" [!] Instance is too short! : {}".format(wavpath))
            self.item_list[index] = self.item_list[index + 1]
            feat_path = self.item_list[index]
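Note on the wavernn_dataset.py hunk: flipping the Yoda condition to `m.shape[-1] < 5` is behaviour-preserving; the guard skips items whose mel spectrogram has fewer than 5 frames and falls back to the next item. A standalone sketch of that logic for reference (`pick_usable_item` and `min_frames` are hypothetical names, not part of this PR; the `/quant/` to `/mel/` path swap is taken from the dataset code):

import numpy as np

def pick_usable_item(item_list, index, min_frames=5):
    # mirrors WaveRNNDataset: mel features live next to the quantized wavs,
    # under a /mel/ folder instead of /quant/
    wavpath, feat_path = item_list[index]
    m = np.load(feat_path.replace("/quant/", "/mel/"))
    if m.shape[-1] < min_frames:  # too few frames to cut a training window
        print(" [!] Instance is too short! : {}".format(wavpath))
        item_list[index] = item_list[index + 1]  # fall back to the next item
        wavpath, feat_path = item_list[index]
        m = np.load(feat_path.replace("/quant/", "/mel/"))
    return wavpath, feat_path, m

One thing worth a second look in the original hunk: after the swap it assigns the whole (wavpath, feat_path) pair to `feat_path` (`feat_path = self.item_list[index]`); the sketch unpacks the pair instead.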
diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py
index e1c4365f..9b637a6a 100644
--- a/TTS/vocoder/models/wavernn.py
+++ b/TTS/vocoder/models/wavernn.py
@@ -42,7 +42,7 @@ class MelResNet(nn.Module):
         self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False)
         self.batch_norm = nn.BatchNorm1d(compute_dims)
         self.layers = nn.ModuleList()
-        for i in range(res_blocks):
+        for _ in range(res_blocks):
             self.layers.append(ResBlock(compute_dims))
         self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
 
@@ -365,7 +365,8 @@ class WaveRNN(nn.Module):
                 (i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio),
             )
 
-    def get_gru_cell(self, gru):
+    @staticmethod
+    def get_gru_cell(gru):
         gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
         gru_cell.weight_hh.data = gru.weight_hh_l0.data
         gru_cell.weight_ih.data = gru.weight_ih_l0.data
@@ -373,13 +374,14 @@
         gru_cell.bias_ih.data = gru.bias_ih_l0.data
         return gru_cell
 
-    def pad_tensor(self, x, pad, side="both"):
+    @staticmethod
+    def pad_tensor(x, pad, side="both"):
         # NB - this is just a quick method i need right now
         # i.e., it won't generalise to other shapes/dims
         b, t, c = x.size()
         total = t + 2 * pad if side == "both" else t + pad
         padded = torch.zeros(b, total, c).cuda()
-        if side == "before" or side == "both":
+        if side in ("before", "both"):
             padded[:, pad : pad + t, :] = x
         elif side == "after":
             padded[:, :t, :] = x
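Note on the wavernn.py hunks: `get_gru_cell` and `pad_tensor` never touch `self`, so `@staticmethod` is the right call. For reference, a quick self-contained check (hypothetical test code, not part of this PR) that the weight and bias copying in `get_gru_cell` yields a cell reproducing one step of a single-layer, unidirectional `nn.GRU`; the `bias_hh` copy sits on the unchanged line between the two hunks:

import torch
import torch.nn as nn

def get_gru_cell(gru):
    # same copying as WaveRNN.get_gru_cell
    gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
    gru_cell.weight_hh.data = gru.weight_hh_l0.data
    gru_cell.weight_ih.data = gru.weight_ih_l0.data
    gru_cell.bias_hh.data = gru.bias_hh_l0.data
    gru_cell.bias_ih.data = gru.bias_ih_l0.data
    return gru_cell

torch.manual_seed(0)
gru = nn.GRU(input_size=32, hidden_size=64, batch_first=True)
cell = get_gru_cell(gru)
x = torch.randn(1, 1, 32)     # (batch, time=1, features)
h0 = torch.zeros(1, 1, 64)    # (layers, batch, hidden)
out, _ = gru(x, h0)
h1 = cell(x[:, 0, :], h0[0])  # GRUCell takes (batch, features) and (batch, hidden)
print(torch.allclose(out[:, 0, :], h1, atol=1e-6))  # expected: True

Also note `pad_tensor` still hard-codes `.cuda()`, so it will fail on CPU-only runs; that may be worth a follow-up.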
diff --git a/TTS/vocoder/utils/distribution.py b/TTS/vocoder/utils/distribution.py
index bfcbdd3f..705c14dc 100644
--- a/TTS/vocoder/utils/distribution.py
+++ b/TTS/vocoder/utils/distribution.py
@@ -11,7 +11,11 @@ def gaussian_loss(y_hat, y, log_std_min=-7.0):
     mean = y_hat[:, :, :1]
     log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
     # TODO: replace with pytorch dist
-    log_probs = -0.5 * (- math.log(2.0 * math.pi) - 2. * log_std - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std)))
+    log_probs = -0.5 * (
+        -math.log(2.0 * math.pi)
+        - 2.0 * log_std
+        - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std))
+    )
     return log_probs.squeeze().mean()
 
 
@@ -19,7 +23,10 @@ def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.0):
     assert y_hat.size(2) == 2
     mean = y_hat[:, :, :1]
     log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
-    dist = Normal(mean, torch.exp(log_std), )
+    dist = Normal(
+        mean,
+        torch.exp(log_std),
+    )
     sample = dist.sample()
     sample = torch.clamp(torch.clamp(sample, min=-scale_factor), max=scale_factor)
     del dist
@@ -36,11 +43,12 @@ def log_sum_exp(x):
 
 
 # It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
-def discretized_mix_logistic_loss(y_hat, y, num_classes=65536,
-                                  log_scale_min=None, reduce=True):
+def discretized_mix_logistic_loss(
+    y_hat, y, num_classes=65536, log_scale_min=None, reduce=True
+):
     if log_scale_min is None:
         log_scale_min = float(np.log(1e-14))
-    y_hat = y_hat.permute(0,2,1)
+    y_hat = y_hat.permute(0, 2, 1)
     assert y_hat.dim() == 3
     assert y_hat.size(1) % 3 == 0
     nr_mix = y_hat.size(1) // 3
@@ -50,17 +58,17 @@ def discretized_mix_logistic_loss(y_hat, y, num_classes=65536,
 
     # unpack parameters. (B, T, num_mixtures) x 3
     logit_probs = y_hat[:, :, :nr_mix]
-    means = y_hat[:, :, nr_mix:2 * nr_mix]
-    log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min)
+    means = y_hat[:, :, nr_mix : 2 * nr_mix]
+    log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min)
 
     # B x T x 1 -> B x T x num_mixtures
     y = y.expand_as(means)
 
     centered_y = y - means
     inv_stdv = torch.exp(-log_scales)
-    plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
+    plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1))
     cdf_plus = torch.sigmoid(plus_in)
-    min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
+    min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1))
     cdf_min = torch.sigmoid(min_in)
 
     # log probability for edge case of 0 (before scaling)
@@ -77,34 +85,35 @@
     mid_in = inv_stdv * centered_y
     # log probability in the center of the bin, to be used in extreme cases
     # (not actually used in our code)
-    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)
+    log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)
 
     # tf equivalent
-    """
-    log_probs = tf.where(x < -0.999, log_cdf_plus,
-                         tf.where(x > 0.999, log_one_minus_cdf_min,
-                                  tf.where(cdf_delta > 1e-5,
-                                           tf.log(tf.maximum(cdf_delta, 1e-12)),
-                                           log_pdf_mid - np.log(127.5))))
-    """
+
+    # log_probs = tf.where(x < -0.999, log_cdf_plus,
+    #                      tf.where(x > 0.999, log_one_minus_cdf_min,
+    #                               tf.where(cdf_delta > 1e-5,
+    #                                        tf.log(tf.maximum(cdf_delta, 1e-12)),
+    #                                        log_pdf_mid - np.log(127.5))))
+
 
     # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
     # for num_classes=65536 case? 1e-7? not sure..
     inner_inner_cond = (cdf_delta > 1e-5).float()
-    inner_inner_out = inner_inner_cond * \
-        torch.log(torch.clamp(cdf_delta, min=1e-12)) + \
-        (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
+    inner_inner_out = inner_inner_cond * torch.log(
+        torch.clamp(cdf_delta, min=1e-12)
+    ) + (1.0 - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
     inner_cond = (y > 0.999).float()
-    inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
+    inner_out = (
+        inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
+    )
     cond = (y < -0.999).float()
-    log_probs = cond * log_cdf_plus + (1. - cond) * inner_out
+    log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out
 
     log_probs = log_probs + F.log_softmax(logit_probs, -1)
 
     if reduce:
         return -torch.mean(log_sum_exp(log_probs))
-    else:
-        return -log_sum_exp(log_probs).unsqueeze(-1)
+    return -log_sum_exp(log_probs).unsqueeze(-1)
 
 
 def sample_from_discretized_mix_logistic(y, log_scale_min=None):
@@ -127,26 +136,27 @@ def sample_from_discretized_mix_logistic(y, log_scale_min=None):
 
     # sample mixture indicator from softmax
     temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
-    temp = logit_probs.data - torch.log(- torch.log(temp))
+    temp = logit_probs.data - torch.log(-torch.log(temp))
     _, argmax = temp.max(dim=-1)
 
     # (B, T) -> (B, T, nr_mix)
     one_hot = to_one_hot(argmax, nr_mix)
     # select logistic parameters
-    means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1)
-    log_scales = torch.clamp(torch.sum(
-        y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min)
+    means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1)
+    log_scales = torch.clamp(
+        torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min
+    )
 
     # sample from logistic & clip to interval
     # we don't actually round to the nearest 8bit value when sampling
     u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
-    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))
+    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u))
 
-    x = torch.clamp(torch.clamp(x, min=-1.), max=1.)
+    x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0)
 
     return x
 
 
-def to_one_hot(tensor, n, fill_with=1.):
+def to_one_hot(tensor, n, fill_with=1.0):
     # we perform one hot encore with respect to the last axis
     one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
     if tensor.is_cuda:
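Note on the distribution.py reformat: spreading `gaussian_loss` across several lines makes it easier to see that, despite the `log_probs` name, the leading `-0.5 * (-...)` flips the sign, so the function already returns the mean negative log-likelihood, i.e. a loss. The `# TODO: replace with pytorch dist` could presumably look like the sketch below (`gaussian_loss_dist` is a hypothetical name, not code from this PR): `-Normal(mean, exp(log_std)).log_prob(y)` expands to `0.5 * log(2 * pi) + log_std + 0.5 * (y - mean)**2 * exp(-2 * log_std)`, which matches the reformatted expression term for term.

import torch
from torch.distributions import Normal

def gaussian_loss_dist(y_hat, y, log_std_min=-7.0):
    mean = y_hat[:, :, :1]
    log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
    # mean negative log-likelihood of y under N(mean, exp(log_std)**2)
    nll = -Normal(mean, torch.exp(log_std)).log_prob(y)
    return nll.squeeze().mean()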