mirror of https://github.com/coqui-ai/TTS.git

commit d6b9a5738a (parent 58117c0c6e)

    Ton of linter fixes
train.py

@@ -20,7 +20,8 @@ from TTS.utils.generic_utils import (count_parameters, create_experiment_folder,
 from TTS.utils.io import (save_best_model, save_checkpoint,
                           load_config, copy_config_file)
 from TTS.utils.training import (NoamLR, check_update, adam_weight_decay,
-                                gradual_training_scheduler, set_weight_decay)
+                                gradual_training_scheduler, set_weight_decay,
+                                setup_torch_training_env)
 from TTS.utils.tensorboard_logger import TensorboardLogger
 from TTS.utils.console_logger import ConsoleLogger
 from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
@@ -32,13 +33,8 @@ from TTS.datasets.preprocess import load_meta_data
 from TTS.utils.radam import RAdam
 from TTS.utils.measures import alignment_diagonal_score

-torch.backends.cudnn.enabled = True
-torch.backends.cudnn.benchmark = False
-torch.manual_seed(54321)
-use_cuda = torch.cuda.is_available()
-num_gpus = torch.cuda.device_count()
-print(" > Using CUDA: ", use_cuda)
-print(" > Number of GPUs: ", num_gpus)
+use_cuda, num_gpus = setup_torch_training_env(True, False)


 def setup_loader(ap, r, is_val=False, verbose=False):
TTS/utils/training.py

@@ -2,6 +2,17 @@ import torch
 import numpy as np


+def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
+    torch.backends.cudnn.enabled = cudnn_enable
+    torch.backends.cudnn.benchmark = cudnn_benchmark
+    torch.manual_seed(54321)
+    use_cuda = torch.cuda.is_available()
+    num_gpus = torch.cuda.device_count()
+    print(" > Using CUDA: ", use_cuda)
+    print(" > Number of GPUs: ", num_gpus)
+    return use_cuda, num_gpus
+
+
 def check_update(model, grad_clip, ignore_stopnet=False):
     r'''Check model gradient against unexpected jumps and failures'''
     skip_flag = False
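The helper consolidates environment setup that was previously duplicated at the top of both training scripts. Its two flags map directly onto the cuDNN switches: enabled turns cuDNN on at all, while benchmark lets cuDNN time several convolution algorithms and cache the fastest one per input shape, which pays off only when shapes stay fixed. That matches how the two call sites in this commit use it:

    # Tacotron training (train.py): variable-length batches, so no autotuning.
    use_cuda, num_gpus = setup_torch_training_env(True, False)

    # GAN vocoder training (vocoder train script): fixed-size segments,
    # so autotuning is worthwhile.
    use_cuda, num_gpus = setup_torch_training_env(True, True)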
TTS/vocoder/datasets/gan_dataset.py

@@ -3,29 +3,10 @@ import glob
 import torch
 import random
 import numpy as np
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset
 from multiprocessing import Manager


-def create_dataloader(hp, args, train):
-    dataset = MelFromDisk(hp, args, train)
-
-    if train:
-        return DataLoader(dataset=dataset,
-                          batch_size=hp.train.batch_size,
-                          shuffle=True,
-                          num_workers=hp.train.num_workers,
-                          pin_memory=True,
-                          drop_last=True)
-    else:
-        return DataLoader(dataset=dataset,
-                          batch_size=1,
-                          shuffle=False,
-                          num_workers=hp.train.num_workers,
-                          pin_memory=True,
-                          drop_last=False)
-
-
 class GANDataset(Dataset):
     """
     GAN Dataset searchs for all the wav files under root path
@@ -55,26 +36,26 @@ class GANDataset(Dataset):
         self.return_segments = return_segments
         self.use_cache = use_cache
         self.use_noise_augment = use_noise_augment
+        self.verbose = verbose

         assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
         self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)

         # map G and D instances
-        self.G_to_D_mappings = [i for i in range(len(self.item_list))]
+        self.G_to_D_mappings = list(range(len(self.item_list)))
         self.shuffle_mapping()

         # cache acoustic features
         if use_cache:
             self.create_feature_cache()


     def create_feature_cache(self):
         self.manager = Manager()
         self.cache = self.manager.list()
         self.cache += [None for _ in range(len(self.item_list))]

-    def find_wav_files(self, path):
+    @staticmethod
+    def find_wav_files(path):
         return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True)

     def __len__(self):
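Three small lints are cleared here and in the hunk above. The deleted create_dataloader was dead code: it instantiated MelFromDisk, a name never defined or imported in this module, so it could not have run. The other two are idiom fixes, sketched below on invented names:

    class Example:
        def __init__(self, items):
            # [i for i in range(n)] merely copies range(n);
            # list(range(n)) says that directly (pylint: unnecessary-comprehension).
            self.mapping = list(range(len(items)))

        # A method that never touches self can be a @staticmethod: it is
        # callable without an instance and pylint stops flagging no-self-use.
        @staticmethod
        def normalize(path):
            return path.lower()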
@@ -88,7 +69,6 @@ class GANDataset(Dataset):
             item1 = self.load_item(idx)
             item2 = self.load_item(idx2)
             return item1, item2
-        else:
-            item1 = self.load_item(idx)
-            return item1
+        item1 = self.load_item(idx)
+        return item1

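This is pylint's no-else-return pattern: once the if-branch returns, the else adds nothing but indentation. A minimal sketch of the transformation, not the repository's exact code:

    # before
    def load(pair, idx, idx2=None):
        if idx2 is not None:
            return pair[idx], pair[idx2]
        else:
            return pair[idx]

    # after: same behavior, one level flatter
    def load(pair, idx, idx2=None):
        if idx2 is not None:
            return pair[idx], pair[idx2]
        return pair[idx]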
TTS/vocoder/layers/losses.py

@@ -51,13 +51,13 @@ class STFTLoss(nn.Module):

 class MultiScaleSTFTLoss(torch.nn.Module):
     def __init__(self,
-                 n_ffts=[1024, 2048, 512],
-                 hop_lengths=[120, 240, 50],
-                 win_lengths=[600, 1200, 240]):
+                 n_ffts=(1024, 2048, 512),
+                 hop_lengths=(120, 240, 50),
+                 win_lengths=(600, 1200, 240)):
         super(MultiScaleSTFTLoss, self).__init__()
         self.loss_funcs = torch.nn.ModuleList()
-        for idx in range(len(n_ffts)):
-            self.loss_funcs.append(STFTLoss(n_ffts[idx], hop_lengths[idx], win_lengths[idx]))
+        for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths):
+            self.loss_funcs.append(STFTLoss(n_fft, hop_length, win_length))

     def forward(self, y_hat, y):
         N = len(self.loss_funcs)
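Two fixes share this hunk. Python evaluates default arguments once, at function definition, so a list default is a single shared object that any mutation poisons for every later call (pylint: dangerous-default-value); tuples are immutable and close that hole. Here the lists were never actually mutated, so the swap is preventive hygiene rather than a bug fix. And zip walks the three parallel sequences without index bookkeeping. An illustrative sketch on invented names:

    def grow(xs=[1, 2]):      # one list object shared by every call
        xs.append(3)
        return xs

    grow()   # [1, 2, 3]
    grow()   # [1, 2, 3, 3]   <- the default itself was mutated

    def build(n_ffts=(1024, 2048, 512), hops=(120, 240, 50)):
        # pairs values positionally; no n_ffts[idx] indexing needed
        return list(zip(n_ffts, hops))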
@@ -81,20 +81,14 @@ class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):

 class MSEGLoss(nn.Module):
     """ Mean Squared Generator Loss """
-    def __init__(self,):
-        super(MSEGLoss, self).__init__()
-
-    def forward(self, score_fake, ):
+    def forward(self, score_fake):
         loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
         return loss_fake


 class HingeGLoss(nn.Module):
     """ Hinge Discriminator Loss """
-    def __init__(self,):
-        super(HingeGLoss, self).__init__()
-
-    def forward(self, score_fake, score_real):
+    def forward(self, score_fake):
         loss_fake = torch.mean(F.relu(1. + score_fake))
         return loss_fake
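An nn.Module subclass only needs __init__ when it has state of its own to set up; otherwise the inherited nn.Module.__init__ runs automatically, and a constructor whose sole line is super().__init__() is dead weight. The old HingeGLoss.forward also accepted a score_real argument it never read. A compact sketch under the same assumptions:

    import torch
    from torch import nn

    class SquaredScore(nn.Module):
        # no __init__: nn.Module's own constructor is enough
        def forward(self, score):
            return torch.mean(torch.sum(score ** 2, dim=[1, 2]))

    loss = SquaredScore()(torch.randn(4, 1, 8))  # usable like any module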
@@ -106,9 +100,6 @@ class HingeGLoss(nn.Module):

 class MSEDLoss(nn.Module):
     """ Mean Squared Discriminator Loss """
-    def __init__(self,):
-        super(MSEDLoss, self).__init__()
-
     def forward(self, score_fake, score_real):
         loss_real = torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
         loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
@@ -118,9 +109,6 @@ class MSEDLoss(nn.Module):

 class HingeDLoss(nn.Module):
     """ Hinge Discriminator Loss """
-    def __init__(self,):
-        super(HingeDLoss, self).__init__()
-
     def forward(self, score_fake, score_real):
         loss_real = torch.mean(F.relu(1. - score_real))
         loss_fake = torch.mean(F.relu(1. + score_fake))
@@ -129,13 +117,10 @@ class HingeDLoss(nn.Module):


 class MelganFeatureLoss(nn.Module):
-    def __init__(self, ):
-        super(MelganFeatureLoss, self).__init__()
-
     def forward(self, fake_feats, real_feats):
         loss_feats = 0
         for fake_feat, real_feat in zip(fake_feats, real_feats):
-            loss_feats += hp.model.feat_match * torch.mean(torch.abs(fake_feat - real_feat))
+            loss_feats += torch.mean(torch.abs(fake_feat - real_feat))
         return loss_feats

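Beyond the constructor cleanup, dropping hp.model.feat_match fixes a latent NameError: no hp exists in this module (it reads like a leftover from the code base this loss was adapted from), so the old line would have crashed the first time it ran. Presumably the weighting now belongs to the caller; that is an assumption, since the hunk itself only removes the factor. A schematic sketch with invented names and values:

    feat_match_weight = 10.0          # invented name and value
    adv_loss, feat_loss = 1.2, 0.3    # stand-ins for computed losses
    loss_G = adv_loss + feat_match_weight * feat_loss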
@@ -147,7 +132,7 @@ class MelganFeatureLoss(nn.Module):
 class GeneratorLoss(nn.Module):
     def __init__(self, C):
         super(GeneratorLoss, self).__init__()
-        assert C.use_mse_gan_loss and C.use_hinge_gan_loss == False,\
+        assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
             " [!] Cannot use HingeGANLoss and MSEGANLoss together."

         self.use_stft_loss = C.use_stft_loss
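The old assert was a genuine bug, not just style: == binds tighter than and, so the condition parsed as C.use_mse_gan_loss and (C.use_hinge_gan_loss == False). That expression is true only when MSE is on and hinge is off, meaning the assert also fired for perfectly legal configurations such as hinge-only. Traced on truth values:

    A, B = False, True       # hinge GAN loss only: a legal configuration
    A and B == False         # parses as A and (B == False) -> False: assert fires
    not (A and B)            # -> True: only the both-enabled case is rejected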
@@ -228,7 +213,7 @@ class GeneratorLoss(nn.Module):
 class DiscriminatorLoss(nn.Module):
     def __init__(self, C):
         super(DiscriminatorLoss, self).__init__()
-        assert C.use_mse_gan_loss and C.use_hinge_gan_loss == False,\
+        assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
             " [!] Cannot use HingeGANLoss and MSEGANLoss together."

         self.use_mse_gan_loss = C.use_mse_gan_loss
TTS/vocoder/layers/melgan.py (path inferred; this file provides the ResidualStack imported further down)

@@ -1,7 +1,4 @@
-import numpy as np
-import torch
 from torch import nn
-from torch.nn import functional as F
 from torch.nn.utils import weight_norm

TTS/vocoder/layers/pqmf.py

@@ -44,6 +44,9 @@ class PQMF(torch.nn.Module):

         self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

+    def forward(self, x):
+        return self.analysis(x)
+
     def analysis(self, x):
         return F.conv1d(x, self.H, padding=self.taps // 2, stride=self.N)

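Defining forward lets the module be invoked as pqmf(x), which routes through nn.Module.__call__ and any registered hooks, instead of requiring the explicit .analysis(...) spelling. A usage sketch; the constructor arguments and output shape are my assumptions, following the usual PQMF convention, not taken from the diff:

    pqmf = PQMF()                  # constructor defaults assumed
    subbands = pqmf(wav)           # wav (B, 1, T) -> roughly (B, N, T // N)
    same = pqmf.analysis(wav)      # direct call, identical result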
TTS/vocoder/layers/pqmf2.py (deleted; inferred from the PQMF2 import removed in the test below)

@@ -1,127 +0,0 @@
-# -*- coding: utf-8 -*-
-
-# Copyright 2020 Tomoki Hayashi
-# MIT License (https://opensource.org/licenses/MIT)
-
-"""Pseudo QMF modules."""
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-from scipy.signal import kaiser
-
-
-def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
-    """Design prototype filter for PQMF.
-
-    This method is based on `A Kaiser window approach for the design of prototype
-    filters of cosine modulated filterbanks`_.
-
-    Args:
-        taps (int): The number of filter taps.
-        cutoff_ratio (float): Cut-off frequency ratio.
-        beta (float): Beta coefficient for kaiser window.
-
-    Returns:
-        ndarray: Impluse response of prototype filter (taps + 1,).
-
-    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
-        https://ieeexplore.ieee.org/abstract/document/681427
-
-    """
-    # check the arguments are valid
-    assert taps % 2 == 0, "The number of taps mush be even number."
-    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
-
-    # make initial filter
-    omega_c = np.pi * cutoff_ratio
-    with np.errstate(invalid='ignore'):
-        h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) \
-            / (np.pi * (np.arange(taps + 1) - 0.5 * taps))
-    h_i[taps // 2] = np.cos(0) * cutoff_ratio  # fix nan due to indeterminate form
-
-    # apply kaiser window
-    w = kaiser(taps + 1, beta)
-    h = h_i * w
-
-    return h
-
-
-class PQMF(torch.nn.Module):
-    """PQMF module.
-
-    This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
-
-    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
-        https://ieeexplore.ieee.org/document/258122
-
-    """
-
-    def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0):
-        """Initilize PQMF module.
-
-        Args:
-            subbands (int): The number of subbands.
-            taps (int): The number of filter taps.
-            cutoff_ratio (float): Cut-off frequency ratio.
-            beta (float): Beta coefficient for kaiser window.
-
-        """
-        super(PQMF, self).__init__()
-
-        # define filter coefficient
-        h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
-        h_analysis = np.zeros((subbands, len(h_proto)))
-        h_synthesis = np.zeros((subbands, len(h_proto)))
-        for k in range(subbands):
-            h_analysis[k] = 2 * h_proto * np.cos((2 * k + 1) * (np.pi / (2 * subbands)) * (np.arange(taps + 1) - ((taps - 1) / 2)) + (-1) ** k * np.pi / 4)
-            h_synthesis[k] = 2 * h_proto * np.cos((2 * k + 1) * (np.pi / (2 * subbands)) * (np.arange(taps + 1) - ((taps - 1) / 2)) - (-1) ** k * np.pi / 4)
-
-        # convert to tensor
-        analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
-        synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)
-
-        # register coefficients as beffer
-        self.register_buffer("analysis_filter", analysis_filter)
-        self.register_buffer("synthesis_filter", synthesis_filter)
-
-        # filter for downsampling & upsampling
-        updown_filter = torch.zeros((subbands, subbands, subbands)).float()
-        for k in range(subbands):
-            updown_filter[k, k, 0] = 1.0
-        self.register_buffer("updown_filter", updown_filter)
-        self.subbands = subbands
-
-        # keep padding info
-        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
-
-    def analysis(self, x):
-        """Analysis with PQMF.
-
-        Args:
-            x (Tensor): Input tensor (B, 1, T).
-
-        Returns:
-            Tensor: Output tensor (B, subbands, T // subbands).
-
-        """
-        x = F.conv1d(self.pad_fn(x), self.analysis_filter)
-        return F.conv1d(x, self.updown_filter, stride=self.subbands)
-
-    def synthesis(self, x):
-        """Synthesis with PQMF.
-
-        Args:
-            x (Tensor): Input tensor (B, subbands, T // subbands).
-
-        Returns:
-            Tensor: Output tensor (B, 1, T).
-
-        """
-        # NOTE(kan-bayashi): Power will be dreased so here multipy by # subbands.
-        # Not sure this is the correct way, it is better to check again.
-        # TODO(kan-bayashi): Understand the reconstruction procedure
-        x = F.conv_transpose1d(x, self.updown_filter * self.subbands, stride=self.subbands)
-        x = F.conv1d(self.pad_fn(x), self.synthesis_filter)
-        return x
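This deletes the second, parallel PQMF implementation (Tomoki Hayashi's MIT-licensed version, per its header), leaving the repository's own pqmf.py as the single copy; the PQMF test further down drops its PQMF2 import to match. For reference, the cosine-modulated filter bank the deleted code constructed, in its own notation (K subbands, T taps, prototype h, n = 0..T):

    h_k^{\mathrm{ana}}[n] = 2\,h[n]\cos\!\Big(\tfrac{(2k+1)\pi}{2K}\big(n - \tfrac{T-1}{2}\big) + (-1)^k\tfrac{\pi}{4}\Big)
    h_k^{\mathrm{syn}}[n] = 2\,h[n]\cos\!\Big(\tfrac{(2k+1)\pi}{2K}\big(n - \tfrac{T-1}{2}\big) - (-1)^k\tfrac{\pi}{4}\Big)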
TTS/vocoder/models/melgan_discriminator.py (path inferred from the MelganDiscriminator hunk that follows)

@@ -1,7 +1,5 @@
 import numpy as np
-import torch
 from torch import nn
-from torch.nn import functional as F
 from torch.nn.utils import weight_norm

@@ -32,7 +30,7 @@ class MelganDiscriminator(nn.Module):

         # downsampling layers
         layer_in_channels = base_channels
-        for idx, downsample_factor in enumerate(downsample_factors):
+        for downsample_factor in downsample_factors:
             layer_out_channels = min(layer_in_channels * downsample_factor,
                                      max_channels)
             layer_kernel_size = downsample_factor * 10 + 1
TTS/vocoder/models/melgan_generator.py (path inferred from the test import below)

@@ -1,7 +1,5 @@
-import math
 import torch
 from torch import nn
-from torch.nn import functional as F
 from torch.nn.utils import weight_norm

 from TTS.vocoder.layers.melgan import ResidualStack
vocoder GAN dataset test (file path not shown in this mirror)

@@ -22,7 +22,7 @@ ok_ljspeech = os.path.exists(test_data_path)
 def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, use_noise_augment, use_cache, num_workers):
     ''' run dataloader with given parameters and check conditions '''
     ap = AudioProcessor(**C.audio)
-    eval_items, train_items = load_wav_data(test_data_path, 10)
+    _, train_items = load_wav_data(test_data_path, 10)
     dataset = GANDataset(ap,
                          train_items,
                          seq_len=seq_len,
@@ -44,9 +44,9 @@ def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, us

     # return random segments or return the whole audio
     if return_segments:
-        for item1, item2 in loader:
+        for item1, _ in loader:
             feat1, wav1 = item1
-            feat2, wav2 = item2
+            # feat2, wav2 = item2
             expected_feat_shape = (batch_size, ap.num_mels, seq_len // hop_len + conv_pad * 2)

             # check shapes
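Binding deliberately ignored values to _ is the conventional answer to pylint's unused-variable warnings in tests: it documents the intent, and the checker skips the name. A generic example:

    quotient, _ = divmod(7, 3)   # remainder intentionally discarded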
vocoder losses test (path not shown)

@@ -1,5 +1,4 @@
 import os
-import unittest
 import torch

 from TTS.vocoder.layers.losses import TorchSTFT, STFTLoss, MultiScaleSTFTLoss
@@ -24,7 +23,7 @@ def test_torch_stft():
     torch_stft = TorchSTFT(ap.n_fft, ap.hop_length, ap.win_length)
     # librosa stft
     wav = ap.load_wav(WAV_FILE)
-    M_librosa = abs(ap._stft(wav))
+    M_librosa = abs(ap._stft(wav))  # pylint: disable=protected-access
     # torch stft
     wav = torch.from_numpy(wav[None, :]).float()
     M_torch = torch_stft(wav)
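ap._stft is a protected member of AudioProcessor, and the test pokes at it on purpose, so the inline marker silences protected-access for that single line. Keeping the disable this narrow is the point; a file-wide disable would also hide future accidental uses. Sketched on an invented class:

    class Processor:
        def __init__(self):
            self._internal = 42

    value = Processor()._internal  # pylint: disable=protected-access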
MelGAN generator test (path not shown)

@@ -1,5 +1,4 @@
 import numpy as np
-import unittest
 import torch

 from TTS.vocoder.models.melgan_generator import MelganGenerator
PQMF test (path not shown)

@@ -1,16 +1,11 @@
 import os
 import torch
-import unittest
-
-import numpy as np
 import soundfile as sf
 from librosa.core import load

-from TTS.tests import get_tests_path, get_tests_input_path, get_tests_output_path
-from TTS.utils.audio import AudioProcessor
-from TTS.utils.io import load_config
+from TTS.tests import get_tests_path, get_tests_input_path
 from TTS.vocoder.layers.pqmf import PQMF
-from TTS.vocoder.layers.pqmf2 import PQMF as PQMF2


 TESTS_PATH = get_tests_path()
TTS/vocoder/train.py (path inferred; this is the GAN vocoder training script)

@@ -17,7 +17,7 @@ from TTS.utils.generic_utils import (KeepAverage, count_parameters,
 from TTS.utils.io import copy_config_file, load_config
 from TTS.utils.radam import RAdam
 from TTS.utils.tensorboard_logger import TensorboardLogger
-from TTS.utils.training import NoamLR
+from TTS.utils.training import setup_torch_training_env, NoamLR
 from TTS.vocoder.datasets.gan_dataset import GANDataset
 from TTS.vocoder.datasets.preprocess import load_wav_data
 # from distribute import (DistributedSampler, apply_gradient_allreduce,
@@ -29,13 +29,8 @@ from TTS.vocoder.utils.generic_utils import (check_config, plot_results,
                                              setup_discriminator,
                                              setup_generator)

-torch.backends.cudnn.enabled = True
-torch.backends.cudnn.benchmark = True
-torch.manual_seed(54321)
-use_cuda = torch.cuda.is_available()
-num_gpus = torch.cuda.device_count()
-print(" > Using CUDA: ", use_cuda)
-print(" > Number of GPUs: ", num_gpus)
+use_cuda, num_gpus = setup_torch_training_env(True, True)


 def setup_loader(ap, is_val=False, verbose=False):
@@ -49,10 +44,11 @@ def setup_loader(ap, is_val=False, verbose=False):
                          pad_short=c.pad_short,
                          conv_pad=c.conv_pad,
                          is_training=not is_val,
-                         return_segments=False if is_val else True,
+                         return_segments=not is_val,
                          use_noise_augment=c.use_noise_augment,
                          use_cache=c.use_cache,
                          verbose=verbose)
+    dataset.shuffle_mapping()
     # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
     loader = DataLoader(dataset,
                         batch_size=1 if is_val else c.batch_size,
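return_segments gets the simplifiable-if-expression treatment, and the new dataset.shuffle_mapping() call re-randomizes the G-to-D pairing the dataset built in its constructor before the loader is created. The boolean rewrite in isolation:

    is_val = True
    return_segments = False if is_val else True   # flagged by pylint
    return_segments = not is_val                  # identical truth table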
@@ -81,11 +77,11 @@ def format_data(data):
         return c_G, x_G, c_D, x_D

     # return a whole audio segment
-    c, x = data
+    co, x = data
     if use_cuda:
-        c = c.cuda(non_blocking=True)
+        co = co.cuda(non_blocking=True)
         x = x.cuda(non_blocking=True)
-    return c, x
+    return co, x, None, None


 def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
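Two things happen here. Renaming c to co stops format_data from shadowing the module-level config c that these scripts keep global (pylint: redefined-outer-name). Padding the whole-audio branch to four return values gives both branches the same arity, so every caller can unpack uniformly, as the evaluate hunk below now does. A sketch on simplified data; the dispatch condition is invented, not the script's actual check:

    def format_data(data):
        if len(data) == 4:                # paired G/D segments
            c_G, x_G, c_D, x_D = data
            return c_G, x_G, c_D, x_D
        co, x = data                      # 'co' avoids shadowing a global 'c'
        return co, x, None, None

    c_G, y_G, _, _ = format_data(("mel", "wav"))  # one unpacking shape everywhere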
@@ -306,7 +302,7 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
         start_time = time.time()

         # format data
-        c_G, y_G = format_data(data)
+        c_G, y_G, _, _ = format_data(data)
         loader_time = time.time() - end_time

         global_step += 1
|
||||||
optimizer_gen.load_state_dict(checkpoint['optimizer'])
|
optimizer_gen.load_state_dict(checkpoint['optimizer'])
|
||||||
model_disc.load_state_dict(checkpoint['model_disc'])
|
model_disc.load_state_dict(checkpoint['model_disc'])
|
||||||
optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
|
optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
|
||||||
except:
|
except KeyError:
|
||||||
print(" > Partial model initialization.")
|
print(" > Partial model initialization.")
|
||||||
model_dict = model_gen.state_dict()
|
model_dict = model_gen.state_dict()
|
||||||
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
|
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
|
||||||
|
|
|
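A bare except catches everything, BaseException included, so a typo, a CUDA out-of-memory error, or a Ctrl-C inside this block would all be reported as "Partial model initialization". Naming the anticipated failure keeps real errors loud. One caveat worth knowing: state-dict shape mismatches raise RuntimeError or ValueError rather than KeyError, so the narrowed clause now treats those as fatal where the old code silently fell back.

    checkpoint = {"model": {}}            # no 'optimizer_disc' key saved
    try:
        state = checkpoint['optimizer_disc']
    except KeyError:                      # only the expected case: key missing
        print(" > Partial model initialization.")
    # anything else still propagates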
TTS/utils/console_logger.py

@@ -91,7 +91,7 @@ class ConsoleLogger():
                 sign = ''
             elif diff > 0:
                 color = tcolors.FAIL
-                sing = '+'
+                sign = '+'
             log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff)
         self.old_eval_loss_dict = avg_loss_dict
         print(log_text, flush=True)
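The sing = '+' typo meant the positive-diff branch assigned a name nobody reads, leaving sign holding whatever an earlier branch set, so increased losses were printed without their '+' marker. It is exactly the class of bug pylint's unused-variable check surfaces:

    diff, sign = 0.5, ''
    if diff > 0:
        sing = '+'       # typo: assigns a new, never-read name
    print(sign)          # prints the stale '' instead of '+'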
TTS/vocoder/utils/io.py (path inferred from the save_model signature)

@@ -2,7 +2,6 @@ import os
 import torch
 import datetime
-

 def save_model(model, optimizer, model_disc, optimizer_disc, current_step,
                epoch, output_path, **kwargs):
     model_state = model.state_dict()