mirror of https://github.com/coqui-ai/TTS.git
fixing pylint errors
This commit is contained in:
parent 878b7c373e
commit e8294cb9db
@@ -1,8 +1,5 @@
 import argparse
-import math
 import os
-import pickle
-import shutil
 import sys
 import traceback
 import time
@@ -11,7 +8,8 @@ import random

 import torch
 from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
+
+# from torch.utils.data.distributed import DistributedSampler
 from TTS.tts.utils.visual import plot_spectrogram
 from TTS.utils.audio import AudioProcessor
@@ -30,7 +28,6 @@ from TTS.utils.generic_utils import (
 )
 from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
 from TTS.vocoder.datasets.preprocess import (
-    load_wav_data,
     find_feat_files,
     load_wav_feat_data,
     preprocess_wav_files,
@@ -322,7 +319,7 @@ def main(args):  # pylint: disable=redefined-outer-name
             CONFIG.data_path, mel_feat_path, CONFIG.eval_split_size
         )
     else:
-        print(f" > No feature data found. Preprocessing...")
+        print(" > No feature data found. Preprocessing...")
         # preprocessing feature data from given wav files
         preprocess_wav_files(OUT_PATH, CONFIG, ap)
         eval_data, train_data = load_wav_feat_data(

@@ -1,5 +1,3 @@
-import os
-import glob
 import torch
 import numpy as np
 from torch.utils.data import Dataset
@@ -42,7 +40,7 @@ class WaveRNNDataset(Dataset):
         wavpath, feat_path = self.item_list[index]
         m = np.load(feat_path.replace("/quant/", "/mel/"))
         # x = self.wav_cache[index]
-        if 5 > m.shape[-1]:
+        if m.shape[-1] < 5:
             print(" [!] Instance is too short! : {}".format(wavpath))
             self.item_list[index] = self.item_list[index + 1]
             feat_path = self.item_list[index]

@@ -42,7 +42,7 @@ class MelResNet(nn.Module):
         self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False)
         self.batch_norm = nn.BatchNorm1d(compute_dims)
         self.layers = nn.ModuleList()
-        for i in range(res_blocks):
+        for _ in range(res_blocks):
             self.layers.append(ResBlock(compute_dims))
         self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
@@ -365,7 +365,8 @@ class WaveRNN(nn.Module):
                 (i * b_size, seq_len * b_size, b_size, gen_rate, realtime_ratio),
             )

-    def get_gru_cell(self, gru):
+    @staticmethod
+    def get_gru_cell(gru):
         gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
         gru_cell.weight_hh.data = gru.weight_hh_l0.data
         gru_cell.weight_ih.data = gru.weight_ih_l0.data
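get_gru_cell copies the weights of the sequence-level nn.GRU into an nn.GRUCell so that WaveRNN inference can advance one sample at a time; the pylint fix only marks it as a staticmethod since it never touches self. A minimal, self-contained sketch of the same weight copy and its step-wise equivalence (toy sizes and illustrative names, not code from this commit):

import torch
import torch.nn as nn

# Copy nn.GRU layer-0 weights into an nn.GRUCell, as get_gru_cell does.
gru = nn.GRU(input_size=8, hidden_size=16, batch_first=True)
cell = nn.GRUCell(gru.input_size, gru.hidden_size)
cell.weight_hh.data = gru.weight_hh_l0.data
cell.weight_ih.data = gru.weight_ih_l0.data
cell.bias_hh.data = gru.bias_hh_l0.data
cell.bias_ih.data = gru.bias_ih_l0.data

x = torch.randn(4, 10, 8)        # (batch, time, features)
out_seq, _ = gru(x)              # full-sequence pass
h = torch.zeros(4, 16)
for t in range(x.size(1)):       # equivalent one-step-at-a-time pass
    h = cell(x[:, t], h)
assert torch.allclose(out_seq[:, -1], h, atol=1e-5)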
@@ -373,13 +374,14 @@ class WaveRNN(nn.Module):
         gru_cell.bias_ih.data = gru.bias_ih_l0.data
         return gru_cell

-    def pad_tensor(self, x, pad, side="both"):
+    @staticmethod
+    def pad_tensor(x, pad, side="both"):
         # NB - this is just a quick method i need right now
         # i.e., it won't generalise to other shapes/dims
         b, t, c = x.size()
         total = t + 2 * pad if side == "both" else t + pad
         padded = torch.zeros(b, total, c).cuda()
-        if side == "before" or side == "both":
+        if side in ("before", "both"):
             padded[:, pad : pad + t, :] = x
         elif side == "after":
             padded[:, :t, :] = x

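pad_tensor zero-pads the time axis of a (batch, time, channels) tensor on one or both sides. Device placement aside, torch.nn.functional.pad gives the same result; a small sketch under that assumption (not part of the commit):

import torch
import torch.nn.functional as F

x = torch.randn(2, 10, 80)       # (batch, time, channels), illustrative shapes
pad = 2
# F.pad pads trailing dimensions first, so (0, 0, pad, pad) targets the time axis,
# matching pad_tensor(x, pad, side="both") up to the .cuda() call.
padded = F.pad(x, (0, 0, pad, pad))
assert padded.shape == (2, 10 + 2 * pad, 80)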
@@ -11,7 +11,11 @@ def gaussian_loss(y_hat, y, log_std_min=-7.0):
     mean = y_hat[:, :, :1]
     log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
     # TODO: replace with pytorch dist
-    log_probs = -0.5 * (- math.log(2.0 * math.pi) - 2. * log_std - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std)))
+    log_probs = -0.5 * (
+        -math.log(2.0 * math.pi)
+        - 2.0 * log_std
+        - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std))
+    )
     return log_probs.squeeze().mean()


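Written out, the reformatted expression equals 0.5 * log(2 * pi) + log_std + 0.5 * (y - mean)^2 / exp(2 * log_std), i.e. the negative log-density of y under Normal(mean, exp(log_std)), so gaussian_loss returns a mean negative log-likelihood. The file's own TODO suggests switching to torch.distributions; a sketch of what that could look like (illustrative function name, not code from this commit):

import torch
from torch.distributions import Normal

def gaussian_nll_reference(y_hat, y, log_std_min=-7.0):
    # Same quantity as gaussian_loss: mean negative log-likelihood of y
    # under Normal(mean, exp(log_std)), with the log-std clamped from below.
    mean = y_hat[:, :, :1]
    log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
    return -Normal(mean, torch.exp(log_std)).log_prob(y).squeeze().mean()

y_hat = torch.randn(2, 100, 2)   # (B, T, 2): predicted mean and log-std
y = torch.randn(2, 100, 1)       # target samples
print(gaussian_nll_reference(y_hat, y))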
@@ -19,7 +23,10 @@ def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.0):
     assert y_hat.size(2) == 2
     mean = y_hat[:, :, :1]
     log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min)
-    dist = Normal(mean, torch.exp(log_std), )
+    dist = Normal(
+        mean,
+        torch.exp(log_std),
+    )
     sample = dist.sample()
     sample = torch.clamp(torch.clamp(sample, min=-scale_factor), max=scale_factor)
     del dist
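The nested clamp here simply bounds the sample to [-scale_factor, scale_factor]; for reference (not part of this commit), a single torch.clamp with both bounds is equivalent:

import torch

sample = torch.randn(3, 4)
scale_factor = 1.0
# One call with min and max matches the nested clamps exactly.
assert torch.equal(
    torch.clamp(torch.clamp(sample, min=-scale_factor), max=scale_factor),
    torch.clamp(sample, min=-scale_factor, max=scale_factor),
)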
@@ -36,11 +43,12 @@ def log_sum_exp(x):


 # It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
-def discretized_mix_logistic_loss(y_hat, y, num_classes=65536,
-                                  log_scale_min=None, reduce=True):
+def discretized_mix_logistic_loss(
+    y_hat, y, num_classes=65536, log_scale_min=None, reduce=True
+):
     if log_scale_min is None:
         log_scale_min = float(np.log(1e-14))
-    y_hat = y_hat.permute(0,2,1)
+    y_hat = y_hat.permute(0, 2, 1)
     assert y_hat.dim() == 3
     assert y_hat.size(1) % 3 == 0
     nr_mix = y_hat.size(1) // 3
@@ -50,17 +58,17 @@ def discretized_mix_logistic_loss(y_hat, y, num_classes=65536,

     # unpack parameters. (B, T, num_mixtures) x 3
     logit_probs = y_hat[:, :, :nr_mix]
-    means = y_hat[:, :, nr_mix:2 * nr_mix]
-    log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min)
+    means = y_hat[:, :, nr_mix : 2 * nr_mix]
+    log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min)

     # B x T x 1 -> B x T x num_mixtures
     y = y.expand_as(means)

     centered_y = y - means
     inv_stdv = torch.exp(-log_scales)
-    plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
+    plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1))
     cdf_plus = torch.sigmoid(plus_in)
-    min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
+    min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1))
     cdf_min = torch.sigmoid(min_in)

     # log probability for edge case of 0 (before scaling)
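The reformatted slices split the permuted (B, T, 3 * nr_mix) prediction into mixture logits, means, and log-scales of equal size. A toy shape walk-through (illustrative sizes, not code from this commit):

import torch

B, T, nr_mix = 2, 100, 10
y_hat = torch.randn(B, 3 * nr_mix, T)                # network output: (B, C, T)
y_hat = y_hat.permute(0, 2, 1)                       # -> (B, T, 3 * nr_mix)

logit_probs = y_hat[:, :, :nr_mix]                   # mixture weights (logits)
means = y_hat[:, :, nr_mix : 2 * nr_mix]             # component means
log_scales = y_hat[:, :, 2 * nr_mix : 3 * nr_mix]    # component log-scales
assert logit_probs.shape == means.shape == log_scales.shape == (B, T, nr_mix)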
@@ -77,33 +85,34 @@ def discretized_mix_logistic_loss(y_hat, y, num_classes=65536,
     mid_in = inv_stdv * centered_y
     # log probability in the center of the bin, to be used in extreme cases
     # (not actually used in our code)
-    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)
+    log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)

     # tf equivalent
-    """
-    log_probs = tf.where(x < -0.999, log_cdf_plus,
-                         tf.where(x > 0.999, log_one_minus_cdf_min,
-                                  tf.where(cdf_delta > 1e-5,
-                                           tf.log(tf.maximum(cdf_delta, 1e-12)),
-                                           log_pdf_mid - np.log(127.5))))
-    """
+    # log_probs = tf.where(x < -0.999, log_cdf_plus,
+    #                      tf.where(x > 0.999, log_one_minus_cdf_min,
+    #                               tf.where(cdf_delta > 1e-5,
+    #                                        tf.log(tf.maximum(cdf_delta, 1e-12)),
+    #                                        log_pdf_mid - np.log(127.5))))
     # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
     # for num_classes=65536 case? 1e-7? not sure..
     inner_inner_cond = (cdf_delta > 1e-5).float()

-    inner_inner_out = inner_inner_cond * \
-        torch.log(torch.clamp(cdf_delta, min=1e-12)) + \
-        (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
+    inner_inner_out = inner_inner_cond * torch.log(
+        torch.clamp(cdf_delta, min=1e-12)
+    ) + (1.0 - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
     inner_cond = (y > 0.999).float()
-    inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
+    inner_out = (
+        inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
+    )
     cond = (y < -0.999).float()
-    log_probs = cond * log_cdf_plus + (1. - cond) * inner_out
+    log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out

     log_probs = log_probs + F.log_softmax(logit_probs, -1)

     if reduce:
         return -torch.mean(log_sum_exp(log_probs))
-    else:
-        return -log_sum_exp(log_probs).unsqueeze(-1)
+    return -log_sum_exp(log_probs).unsqueeze(-1)


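The float-mask arithmetic above picks between the three log-probability branches exactly as the commented tf.where expression does: extreme target values use the one-sided CDF terms, everything else uses the bin probability, with the mid-bin density as a fallback when cdf_delta underflows. A self-contained sketch of the PyTorch-native equivalent (toy tensors and illustrative names, not code from this commit):

import math
import torch

B, T, nr_mix, num_classes = 2, 50, 10, 65536
y = torch.empty(B, T, nr_mix).uniform_(-1, 1)
log_cdf_plus = torch.randn(B, T, nr_mix)
log_one_minus_cdf_min = torch.randn(B, T, nr_mix)
cdf_delta = torch.rand(B, T, nr_mix)
log_pdf_mid = torch.randn(B, T, nr_mix)

# Nested torch.where mirrors the commented tf.where selection logic.
log_probs = torch.where(
    y < -0.999,
    log_cdf_plus,
    torch.where(
        y > 0.999,
        log_one_minus_cdf_min,
        torch.where(
            cdf_delta > 1e-5,
            torch.log(torch.clamp(cdf_delta, min=1e-12)),
            log_pdf_mid - math.log((num_classes - 1) / 2),
        ),
    ),
)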
@@ -127,26 +136,27 @@ def sample_from_discretized_mix_logistic(y, log_scale_min=None):

     # sample mixture indicator from softmax
     temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
-    temp = logit_probs.data - torch.log(- torch.log(temp))
+    temp = logit_probs.data - torch.log(-torch.log(temp))
     _, argmax = temp.max(dim=-1)

     # (B, T) -> (B, T, nr_mix)
     one_hot = to_one_hot(argmax, nr_mix)
     # select logistic parameters
-    means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1)
-    log_scales = torch.clamp(torch.sum(
-        y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min)
+    means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1)
+    log_scales = torch.clamp(
+        torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min
+    )
     # sample from logistic & clip to interval
     # we don't actually round to the nearest 8bit value when sampling
     u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
-    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))
+    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u))

-    x = torch.clamp(torch.clamp(x, min=-1.), max=1.)
+    x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0)

     return x


-def to_one_hot(tensor, n, fill_with=1.):
+def to_one_hot(tensor, n, fill_with=1.0):
     # we perform one hot encore with respect to the last axis
     one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
     if tensor.is_cuda:
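The line "temp = logit_probs.data - torch.log(-torch.log(temp))" is the Gumbel-max trick: adding Gumbel noise to the mixture logits and taking the argmax draws a component index distributed as softmax(logit_probs). A standalone sketch of the trick (illustrative values, not code from this commit):

import torch

torch.manual_seed(0)
logits = torch.tensor([2.0, 0.5, -1.0])

# argmax(logits + Gumbel noise) samples from Categorical(softmax(logits))
u = torch.empty(10000, 3).uniform_(1e-5, 1.0 - 1e-5)
gumbel = -torch.log(-torch.log(u))
samples = torch.argmax(logits + gumbel, dim=-1)

empirical = torch.bincount(samples, minlength=3).float() / samples.numel()
print(empirical)                       # close to ...
print(torch.softmax(logits, dim=-1))   # ... the softmax probabilities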