mirror of https://github.com/coqui-ai/TTS.git
initial commit intro. to vocoder submodule
This commit is contained in:
parent
e9f3a476cf
commit
0bb0ba182e
|
@ -0,0 +1,35 @@
|
|||
# Mozilla TTS Vocoders (Experimental)
|
||||
|
||||
We provide here different vocoder implementations which can be combined with our TTS models to enable "FASTER THAN REAL-TIME" end-to-end TTS stack.
|
||||
|
||||
Currently, there are implementations of the following models.
|
||||
|
||||
- Melgan
|
||||
- MultiBand-Melgan
|
||||
- GAN-TTS (Discriminator Only)
|
||||
|
||||
It is also very easy to adapt different vocoder models as we provide here a flexible and modular (but not too modular) framework.
|
||||
|
||||
## Training a model
|
||||
|
||||
An example [Colab Notebook]() that trains MelGAN on the LJSpeech dataset will be available here soon.
|
||||
|
||||
In order to train a new model, you need to collect all your wav files under a common parent folder and give this path to the `data_path` field in `config.json`.
|
||||
|
||||
You need to define the other relevant parameters in your `config.json` and then start training with the following command from the Mozilla TTS root path.
|
||||
|
||||
```CUDA_VISIBLE_DEVICES='1' python vocoder/train.py --config_path path/to/config.json```
|
||||
|
||||
Example config files can be found under the `vocoder/configs/` folder.
|
||||
|
||||
You can continue a previous training by the following command.
|
||||
|
||||
```CUDA_VISIBLE_DEVICES='1' python vocoder/train.py --continue_path path/to/your/model/folder```
|
||||
|
||||
You can fine-tune a pre-trained model by the following command.
|
||||
|
||||
```CUDA_VISIBLE_DEVICES='1' python vocoder/train.py --restore_path path/to/your/model.pth.tar```
|
||||
|
||||
Restoring a model starts a new training run in a different output folder; it only restores the model weights from the given checkpoint file. Continuing a training, in contrast, resumes from the exact state in which the previous training run left off.
|
||||
|
||||
You can also follow your training runs on Tensorboard as you do with our TTS models.
|
|
@ -0,0 +1,138 @@
|
|||
{
|
||||
"run_name": "melgan",
|
||||
"run_description": "melgan initial run",
|
||||
|
||||
// AUDIO PARAMETERS
|
||||
"audio":{
|
||||
// stft parameters
|
||||
"num_freq": 513, // number of stft frequency levels. Size of the linear spectrogram frame.
|
||||
"win_length": 1024, // stft window length in samples.
|
||||
"hop_length": 256, // stft window hop-length in samples.
|
||||
"frame_length_ms": null, // stft window length in ms. If null, 'win_length' is used.
|
||||
"frame_shift_ms": null, // stft window hop-length in ms. If null, 'hop_length' is used.
|
||||
|
||||
// Audio processing parameters
|
||||
"sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
|
||||
"preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
|
||||
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
|
||||
|
||||
// Silence trimming
|
||||
"do_trim_silence": true,// enable trimming of silence from the audio as you load it. LJspeech (false), TWEB (false), Nancy (true)
|
||||
"trim_db": 60, // threshold for trimming silence. Set this according to your dataset.
|
||||
|
||||
// Griffin-Lim
|
||||
"power": 1.5, // value to sharpen wav signals after GL algorithm.
|
||||
"griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
|
||||
|
||||
// MelSpectrogram parameters
|
||||
"num_mels": 80, // size of the mel spec frame.
|
||||
"mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
|
||||
"mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
|
||||
|
||||
// Normalization parameters
|
||||
"signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.
|
||||
"min_level_db": -100, // lower bound for normalization
|
||||
"symmetric_norm": true, // move normalization to range [-1, 1]
|
||||
"max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
|
||||
"clip_norm": true, // clip normalized values into the range.
|
||||
"stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based normalization is used and other normalization params are ignored
|
||||
},
|
||||
|
||||
// DISTRIBUTED TRAINING
|
||||
// "distributed":{
|
||||
// "backend": "nccl",
|
||||
// "url": "tcp:\/\/localhost:54321"
|
||||
// },
|
||||
|
||||
// MODEL PARAMETERS
|
||||
"use_pqmf": true,
|
||||
|
||||
// LOSS PARAMETERS
|
||||
"use_stft_loss": true,
|
||||
"use_mse_gan_loss": true,
|
||||
"use_hinge_gan_loss": false,
|
||||
"use_feat_match_loss": false, // use only with melgan discriminators
|
||||
|
||||
"stft_loss_alpha": 1,
|
||||
"mse_gan_loss_alpha": 1,
|
||||
"hinge_gan_loss_alpha": 1,
|
||||
"feat_match_loss_alpha": 10.0,
|
||||
|
||||
"stft_loss_params": {
|
||||
"n_ffts": [1024, 2048, 512],
|
||||
"hop_lengths": [120, 240, 50],
|
||||
"win_lengths": [600, 1200, 240]
|
||||
},
|
||||
"target_loss": "avg_G_loss", // loss value to pick the best model
|
||||
|
||||
// DISCRIMINATOR
|
||||
"discriminator_model": "melgan_multiscale_discriminator",
|
||||
"discriminator_model_params":{
|
||||
"base_channels": 16,
|
||||
"max_channels":1024,
|
||||
"downsample_factors":[4, 4, 4, 4]
|
||||
},
|
||||
"steps_to_start_discriminator": 100000, // steps required to start GAN training.
|
||||
|
||||
// "discriminator_model": "random_window_discriminator",
|
||||
// "discriminator_model_params":{
|
||||
// "uncond_disc_donwsample_factors": [8, 4],
|
||||
// "cond_disc_downsample_factors": [[8, 4, 2, 2, 2], [8, 4, 2, 2], [8, 4, 2], [8, 4], [4, 2, 2]],
|
||||
// "cond_disc_out_channels": [[128, 128, 256, 256], [128, 256, 256], [128, 256], [256], [128, 256]],
|
||||
// "window_sizes": [512, 1024, 2048, 4096, 8192]
|
||||
// },
|
||||
|
||||
|
||||
// GENERATOR
|
||||
"generator_model": "multiband_melgan_generator",
|
||||
"generator_model_params": {
|
||||
"upsample_factors":[2 ,2, 4, 4],
|
||||
"num_res_blocks": 4
|
||||
},
|
||||
|
||||
// DATASET
|
||||
"data_path": "/home/erogol/Data/LJSpeech-1.1/wavs/",
|
||||
"seq_len": 16384,
|
||||
"pad_short": 2000,
|
||||
"conv_pad": 0,
|
||||
"use_noise_augment": true,
|
||||
"use_cache": true,
|
||||
|
||||
"reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
|
||||
|
||||
// TRAINING
|
||||
"batch_size": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
||||
|
||||
// VALIDATION
|
||||
"run_eval": true,
|
||||
"test_delay_epochs": 10, //Until attention is aligned, testing only wastes computation time.
|
||||
"test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences.
|
||||
|
||||
// OPTIMIZER
|
||||
"noam_schedule": true, // use noam warmup and lr schedule.
|
||||
"grad_clip": 1.0, // upper limit for gradients for clipping.
|
||||
"epochs": 1000, // total number of epochs to train.
|
||||
"wd": 0.000001, // Weight decay weight.
|
||||
"lr_gen": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
|
||||
"lr_disc": 0.0001,
|
||||
"warmup_steps_gen": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
|
||||
"warmup_steps_disc": 4000,
|
||||
"gen_clip_grad": 10.0,
|
||||
"disc_clip_grad": 10.0,
|
||||
|
||||
// TENSORBOARD and LOGGING
|
||||
"print_step": 25, // Number of steps between training logs on the console.
|
||||
"print_eval": false, // If True, it prints intermediate loss values in evaluation.
|
||||
"save_step": 10000, // Number of training steps between saving training stats and checkpoints.
|
||||
"checkpoint": true, // If true, it saves checkpoints per "save_step"
|
||||
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
||||
|
||||
// DATA LOADING
|
||||
"num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
|
||||
"num_val_loader_workers": 4, // number of evaluation data loader processes.
|
||||
"eval_split_size": 10,
|
||||
|
||||
// PATHS
|
||||
"output_path": "/home/erogol/Models/LJSpeech/"
|
||||
}
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
import os
|
||||
import glob
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from multiprocessing import Manager
|
||||
|
||||
|
||||
def create_dataloader(hp, args, train):
    """Build the ``DataLoader`` for the training or the evaluation split.

    Training batches are shuffled, sized by ``hp.train.batch_size`` and drop
    the trailing incomplete batch; evaluation yields single items in order.
    Both loaders use ``hp.train.num_workers`` workers and pinned memory.

    NOTE(review): ``MelFromDisk`` is neither defined nor imported in the
    visible part of this module (the dataset class defined here is
    ``GANDataset``) -- confirm where it is expected to come from.
    """
    dataset = MelFromDisk(hp, args, train)

    # options common to both splits
    shared = dict(dataset=dataset,
                  num_workers=hp.train.num_workers,
                  pin_memory=True)

    if not train:
        return DataLoader(batch_size=1,
                          shuffle=False,
                          drop_last=False,
                          **shared)
    return DataLoader(batch_size=hp.train.batch_size,
                      shuffle=True,
                      drop_last=True,
                      **shared)
|
||||
|
||||
|
||||
class GANDataset(Dataset):
    """
    GAN Dataset loads the given wav files, converts them to acoustic
    features on the fly and returns random segments of (feature, audio)
    couples.

    Args:
        ap: audio processor; must provide ``load_wav(path)`` and
            ``melspectrogram(wav)``.
        items: list of wav file paths.
        seq_len: number of audio samples in a returned segment. Must be a
            multiple of ``hop_len``.
        hop_len: number of audio samples per feature frame.
        pad_short: clips shorter than ``seq_len + pad_short`` samples are
            zero-padded up to that length.
        conv_pad: adds ``2 * conv_pad`` extra feature frames to each segment.
        is_training: together with ``use_noise_augment`` enables the noise
            augmentation.
        return_segments: if True ``__getitem__`` returns a pair of random
            segments, otherwise the full clip.
        use_noise_augment: add low-amplitude Gaussian noise to the audio.
        use_cache: keep loaded (audio, feature) couples in a shared list.
        verbose: currently unused.
    """
    def __init__(self,
                 ap,
                 items,
                 seq_len,
                 hop_len,
                 pad_short,
                 conv_pad=2,
                 is_training=True,
                 return_segments=True,
                 use_noise_augment=False,
                 use_cache=False,
                 verbose=False):

        self.ap = ap
        self.item_list = items
        self.seq_len = seq_len
        self.hop_len = hop_len
        self.pad_short = pad_short
        self.conv_pad = conv_pad
        self.is_training = is_training
        self.return_segments = return_segments
        self.use_cache = use_cache
        self.use_noise_augment = use_noise_augment

        assert seq_len % hop_len == 0, " [!] seq_len has to be a multiple of hop_len."
        self.feat_frame_len = seq_len // hop_len + (2 * conv_pad)

        # map G and D instances
        self.G_to_D_mappings = list(range(len(self.item_list)))
        self.shuffle_mapping()

        # cache acoustic features
        if use_cache:
            self.create_feature_cache()

    def create_feature_cache(self):
        """Create a manager-backed list with one (initially empty) slot per item."""
        self.manager = Manager()
        self.cache = self.manager.list()
        self.cache += [None for _ in range(len(self.item_list))]

    def find_wav_files(self, path):
        """Return all ``*.wav`` files found recursively under ``path``."""
        return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True)

    def __len__(self):
        return len(self.item_list)

    def __getitem__(self, idx):
        """ Return different items for Generator and Discriminator and
        cache acoustic features """
        if self.return_segments:
            idx2 = self.G_to_D_mappings[idx]
            item1 = self.load_item(idx)
            item2 = self.load_item(idx2)
            return item1, item2
        item1 = self.load_item(idx)
        return item1

    def shuffle_mapping(self):
        """Re-draw which item is paired with each index in ``__getitem__``."""
        random.shuffle(self.G_to_D_mappings)

    def load_item(self, idx):
        """ load (audio, feat) couple

        Returns:
            tuple: ``(mel, audio)`` float tensors; a random aligned segment
            when ``return_segments`` is set, the whole clip otherwise.
        """
        wavpath = self.item_list[idx]

        cached = self.cache[idx] if self.use_cache else None
        if cached is not None:
            audio, mel = cached
        else:
            audio = self.ap.load_wav(wavpath)

            # Pad short clips BEFORE computing the features. Padding only the
            # audio afterwards was undone by the truncation to the feature
            # length below and left fewer frames than `feat_frame_len`, which
            # made `random.randint` fail with a negative upper bound.
            min_len = self.seq_len + self.pad_short
            if len(audio) < min_len:
                audio = np.pad(audio, (0, min_len - len(audio)),
                               mode='constant', constant_values=0.0)

            mel = self.ap.melspectrogram(audio)

            # fill the cache slot; it was previously read but never written
            if self.use_cache:
                self.cache[idx] = (audio, mel)

        # correct the audio length wrt padding applied in stft
        audio = np.pad(audio, (0, self.hop_len), mode="edge")
        audio = audio[:mel.shape[-1] * self.hop_len]
        assert mel.shape[-1] * self.hop_len == audio.shape[-1], f' [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}'

        audio = torch.from_numpy(audio).float().unsqueeze(0)
        mel = torch.from_numpy(mel).float().squeeze(0)

        if self.return_segments:
            # pick a random frame window and the audio samples aligned to it
            max_mel_start = mel.shape[1] - self.feat_frame_len
            mel_start = random.randint(0, max_mel_start)
            mel_end = mel_start + self.feat_frame_len
            mel = mel[:, mel_start:mel_end]

            audio_start = mel_start * self.hop_len
            audio = audio[:, audio_start:audio_start + self.seq_len]

        if self.use_noise_augment and self.is_training and self.return_segments:
            # noise with a standard deviation of 1 / 32768
            audio = audio + (1 / 32768) * torch.randn_like(audio)
        return (mel, audio)
|
|
@ -0,0 +1,16 @@
|
|||
import glob
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def find_wav_files(data_path):
    """Recursively collect the paths of all ``*.wav`` files under ``data_path``."""
    pattern = os.path.join(data_path, '**', '*.wav')
    return glob.glob(pattern, recursive=True)
|
||||
|
||||
|
||||
def load_wav_data(data_path, eval_split_size):
    """Split the wav files under ``data_path`` into evaluation and training lists.

    The file list is shuffled after re-seeding NumPy's global random state
    with 0; the first ``eval_split_size`` paths form the evaluation set and
    the remaining ones the training set.

    Returns:
        tuple: ``(eval_paths, train_paths)``.
    """
    all_paths = find_wav_files(data_path)
    np.random.seed(0)
    np.random.shuffle(all_paths)
    eval_paths = all_paths[:eval_split_size]
    train_paths = all_paths[eval_split_size:]
    return eval_paths, train_paths
|
|
@ -0,0 +1,273 @@
|
|||
import torch
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class TorchSTFT():
    """Magnitude spectrogram computed with ``torch.stft``.

    Args:
        n_fft: FFT size.
        hop_length: hop between successive frames, in samples.
        win_length: window length, in samples.
        window: name of a window factory in ``torch`` (e.g. ``'hann_window'``)
            that is called with ``win_length``.
    """
    def __init__(self, n_fft, hop_length, win_length, window='hann_window'):
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = getattr(torch, window)(win_length)

    def __call__(self, x):
        """Return the clamped STFT magnitude of ``x``.

        The result has the layout ``torch.stft`` produces for ``x`` without
        the real/imaginary axis (B x D x T for batched input), with every
        value at least ``sqrt(1e-8)``.
        """
        # This class is not an nn.Module, so moving the owning loss module to
        # another device never moves the window; follow the input instead.
        if self.window.device != x.device:
            self.window = self.window.to(x.device)

        stft_kwargs = dict(center=True,
                           pad_mode="constant",  # compatible with audio.py
                           normalized=False,
                           onesided=True)
        try:
            # current PyTorch requires an explicit `return_complex`
            spec = torch.stft(x, self.n_fft, self.hop_length, self.win_length,
                              self.window, return_complex=True, **stft_kwargs)
            real, imag = spec.real, spec.imag
        except TypeError:
            # older PyTorch: no `return_complex`, output is B x D x T x 2
            spec = torch.stft(x, self.n_fft, self.hop_length, self.win_length,
                              self.window, **stft_kwargs)
            real, imag = spec[..., 0], spec[..., 1]
        # clamp before the sqrt to keep the value and its gradient finite
        return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-8))
|
||||
|
||||
|
||||
#################################
|
||||
# GENERATOR LOSSES
|
||||
#################################
|
||||
|
||||
|
||||
class STFTLoss(nn.Module):
    """Single-resolution STFT loss.

    ``forward`` returns the L1 distance between log-magnitudes and the
    spectral convergence (Frobenius norm of the magnitude difference,
    relative to the norm of the target magnitude).
    """
    def __init__(self, n_fft, hop_length, win_length):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.stft = TorchSTFT(n_fft, hop_length, win_length)

    def forward(self, y_hat, y):
        """Return ``(loss_mag, loss_sc)`` for prediction ``y_hat`` and target ``y``."""
        pred_mag = self.stft(y_hat)
        target_mag = self.stft(y)

        # magnitude loss
        log_target = torch.log(target_mag)
        log_pred = torch.log(pred_mag)
        loss_mag = F.l1_loss(log_target, log_pred)

        # spectral convergence loss
        diff_norm = torch.norm(target_mag - pred_mag, p="fro")
        target_norm = torch.norm(target_mag, p="fro")
        loss_sc = diff_norm / target_norm
        return loss_mag, loss_sc
|
||||
|
||||
|
||||
class MultiScaleSTFTLoss(torch.nn.Module):
    """Average of several ``STFTLoss`` terms computed at different resolutions.

    The i-th entries of ``n_ffts``, ``hop_lengths`` and ``win_lengths``
    configure the i-th resolution.
    """
    def __init__(self,
                 n_ffts=[1024, 2048, 512],
                 hop_lengths=[120, 240, 50],
                 win_lengths=[600, 1200, 240]):
        super().__init__()
        self.loss_funcs = torch.nn.ModuleList()
        for scale, n_fft in enumerate(n_ffts):
            scale_loss = STFTLoss(n_fft, hop_lengths[scale], win_lengths[scale])
            self.loss_funcs.append(scale_loss)

    def forward(self, y_hat, y):
        """Return ``(loss_mag, loss_sc)`` averaged over all resolutions."""
        mag_terms = []
        sc_terms = []
        for loss_func in self.loss_funcs:
            mag, sc = loss_func(y_hat, y)
            mag_terms.append(mag)
            sc_terms.append(sc)
        num_scales = len(self.loss_funcs)
        loss_sc = sum(sc_terms) / num_scales
        loss_mag = sum(mag_terms) / num_scales
        return loss_mag, loss_sc
|
||||
|
||||
|
||||
class MSEGLoss(nn.Module):
    """ Mean Squared Generator Loss

    The discriminator loss in this module trains real scores towards 1 and
    fake scores towards 0, so the generator has to push the scores of its
    samples towards 1. (Penalizing ``score_fake ** 2`` instead would train
    the generator towards the discriminator's "fake" target.)
    """
    def __init__(self,):
        super().__init__()

    def forward(self, score_fake, ):
        """Return the squared distance of ``score_fake`` from 1, summed over
        dims 1 and 2 and averaged over dim 0."""
        loss_fake = torch.mean(torch.sum(torch.pow(score_fake - 1.0, 2), dim=[1, 2]))
        return loss_fake
|
||||
|
||||
|
||||
class HingeGLoss(nn.Module):
    """ Hinge Generator Loss

    The hinge discriminator loss in this module penalizes fake scores above
    -1 and real scores below 1, so the generator is penalized while the
    scores of its samples are below 1.
    """
    def __init__(self,):
        super().__init__()

    def forward(self, score_fake, score_real=None):
        """Return ``mean(relu(1 - score_fake))``.

        ``score_real`` is not used; it is kept (now optional) so that both
        the one-argument call made by ``GeneratorLoss`` and the previous
        two-argument form work.
        """
        loss_fake = torch.mean(F.relu(1. - score_fake))
        return loss_fake
|
||||
|
||||
|
||||
##################################
|
||||
# DISCRIMINATOR LOSSES
|
||||
##################################
|
||||
|
||||
|
||||
class MSEDLoss(nn.Module):
    """ Mean Squared Discriminator Loss

    Real scores are trained towards 1 and fake scores towards 0.
    """
    def __init__(self,):
        super().__init__()

    def forward(self, score_fake, score_real):
        """Return ``(loss_d, loss_real, loss_fake)``; squared errors are
        summed over dims 1 and 2 and averaged over dim 0."""
        real_err = (score_real - 1.0).pow(2)
        fake_err = score_fake.pow(2)
        loss_real = real_err.sum(dim=[1, 2]).mean()
        loss_fake = fake_err.sum(dim=[1, 2]).mean()
        return loss_real + loss_fake, loss_real, loss_fake
|
||||
|
||||
|
||||
class HingeDLoss(nn.Module):
    """ Hinge Discriminator Loss

    Penalizes real scores below 1 and fake scores above -1.
    """
    def __init__(self,):
        super().__init__()

    def forward(self, score_fake, score_real):
        """Return ``(loss_d, loss_real, loss_fake)``, each a mean over all scores."""
        real_margin = F.relu(1. - score_real)
        fake_margin = F.relu(1. + score_fake)
        loss_real = real_margin.mean()
        loss_fake = fake_margin.mean()
        return loss_real + loss_fake, loss_real, loss_fake
|
||||
|
||||
|
||||
class MelganFeatureLoss(nn.Module):
    """Feature matching loss: summed mean absolute difference between
    corresponding fake and real feature maps."""
    def __init__(self, ):
        super().__init__()

    def forward(self, fake_feats, real_feats):
        """Return the sum over feature pairs of ``mean(|fake - real|)``.

        Args:
            fake_feats: iterable of feature tensors computed on generated audio.
            real_feats: iterable of matching feature tensors computed on real audio.
        """
        loss_feats = 0
        for fake_feat, real_feat in zip(fake_feats, real_feats):
            # No weighting here: the previous `hp.model.feat_match` factor
            # referenced an undefined name, and the weight is applied by
            # GeneratorLoss through `feat_match_loss_alpha`.
            loss_feats += torch.mean(torch.abs(fake_feat - real_feat))
        return loss_feats
|
||||
|
||||
|
||||
##################################
|
||||
# LOSS WRAPPERS
|
||||
##################################
|
||||
|
||||
|
||||
class GeneratorLoss(nn.Module):
    """Weighted sum of the generator loss terms enabled in the config ``C``.

    ``C`` provides the ``use_*`` switches, the matching ``*_alpha`` weights
    and, when the STFT loss is enabled, ``stft_loss_params``.
    """
    def __init__(self, C):
        super().__init__()
        # only the combination of both GAN losses is invalid; either one
        # alone, or neither, is accepted
        assert not (C.use_mse_gan_loss and C.use_hinge_gan_loss),\
            " [!] Cannot use HingeGANLoss and MSEGANLoss together."

        self.use_stft_loss = C.use_stft_loss
        self.use_mse_gan_loss = C.use_mse_gan_loss
        self.use_hinge_gan_loss = C.use_hinge_gan_loss
        self.use_feat_match_loss = C.use_feat_match_loss

        self.stft_loss_alpha = C.stft_loss_alpha
        self.mse_gan_loss_alpha = C.mse_gan_loss_alpha
        self.hinge_gan_loss_alpha = C.hinge_gan_loss_alpha
        self.feat_match_loss_alpha = C.feat_match_loss_alpha

        if C.use_stft_loss:
            self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
        if C.use_mse_gan_loss:
            self.mse_loss = MSEGLoss()
        if C.use_hinge_gan_loss:
            self.hinge_loss = HingeGLoss()
        if C.use_feat_match_loss:
            self.feat_match_loss = MelganFeatureLoss()

    @staticmethod
    def _score_loss(loss_func, scores_fake):
        """Apply ``loss_func`` to one score tensor, or sum it over a list of them."""
        if isinstance(scores_fake, list):
            return sum(loss_func(score_fake) for score_fake in scores_fake)
        return loss_func(scores_fake)

    def forward(self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None):
        """Compute the enabled loss terms.

        Args:
            y_hat: generated waveform, used by the STFT loss.
            y: target waveform, used by the STFT loss.
            scores_fake: discriminator score(s) for generated audio, a tensor
                or a list of tensors. GAN terms are skipped when it is None.
            feats_fake: discriminator features of generated audio.
            feats_real: discriminator features of real audio. The feature
                matching term is skipped unless both feature inputs are given.

        Returns:
            dict: the individual ``G_*`` terms that were computed plus their
            weighted sum under ``'G_loss'``.
        """
        loss = 0
        return_dict = {}

        # STFT Loss
        if self.use_stft_loss:
            stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat.squeeze(1), y.squeeze(1))
            return_dict['G_stft_loss_mg'] = stft_loss_mg
            return_dict['G_stft_loss_sc'] = stft_loss_sc
            loss = loss + self.stft_loss_alpha * (stft_loss_mg + stft_loss_sc)

        # Fake Losses
        if self.use_mse_gan_loss and scores_fake is not None:
            mse_fake_loss = self._score_loss(self.mse_loss, scores_fake)
            return_dict['G_mse_fake_loss'] = mse_fake_loss
            loss = loss + self.mse_gan_loss_alpha * mse_fake_loss

        # was `not scores_fake is not None`, i.e. it only ran without scores
        if self.use_hinge_gan_loss and scores_fake is not None:
            hinge_fake_loss = self._score_loss(self.hinge_loss, scores_fake)
            return_dict['G_hinge_fake_loss'] = hinge_fake_loss
            loss = loss + self.hinge_gan_loss_alpha * hinge_fake_loss

        # Feature Matching Loss
        # was `not feats_fake`, i.e. it never ran when features were passed
        if self.use_feat_match_loss and feats_fake is not None and feats_real is not None:
            feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
            return_dict['G_feat_match_loss'] = feat_match_loss
            loss = loss + self.feat_match_loss_alpha * feat_match_loss

        return_dict['G_loss'] = loss
        return return_dict
|
||||
|
||||
|
||||
class DiscriminatorLoss(nn.Module):
    """Weighted sum of the discriminator GAN loss terms enabled in the config ``C``.

    ``C`` provides ``use_mse_gan_loss`` / ``use_hinge_gan_loss`` and the
    matching ``*_alpha`` weights.
    """
    def __init__(self, C):
        super().__init__()
        # only the combination of both GAN losses is invalid; either one
        # alone, or neither, is accepted
        assert not (C.use_mse_gan_loss and C.use_hinge_gan_loss),\
            " [!] Cannot use HingeGANLoss and MSEGANLoss together."

        self.use_mse_gan_loss = C.use_mse_gan_loss
        self.use_hinge_gan_loss = C.use_hinge_gan_loss

        self.mse_gan_loss_alpha = C.mse_gan_loss_alpha
        self.hinge_gan_loss_alpha = C.hinge_gan_loss_alpha

        if C.use_mse_gan_loss:
            self.mse_loss = MSEDLoss()
        if C.use_hinge_gan_loss:
            self.hinge_loss = HingeDLoss()

    @staticmethod
    def _score_loss(loss_func, scores_fake, scores_real):
        """Return ``(total, real, fake)`` from ``loss_func``, summed over the
        score pairs when the scores are given as lists."""
        if not isinstance(scores_fake, list):
            return loss_func(scores_fake, scores_real)
        total_sum, real_sum, fake_sum = 0, 0, 0
        for score_fake, score_real in zip(scores_fake, scores_real):
            total_loss, real_loss, fake_loss = loss_func(score_fake, score_real)
            total_sum = total_sum + total_loss
            real_sum = real_sum + real_loss
            fake_sum = fake_sum + fake_loss
        return total_sum, real_sum, fake_sum

    def forward(self, scores_fake, scores_real):
        """Compute the enabled loss terms.

        Args:
            scores_fake: discriminator score(s) for generated audio, a tensor
                or a list of tensors.
            scores_real: discriminator score(s) for real audio, in the same form.

        Returns:
            dict: the individual ``D_*`` terms that were computed plus their
            weighted sum under ``'D_loss'``.
        """
        loss = 0
        return_dict = {}

        if self.use_mse_gan_loss:
            mse_gan_loss, mse_gan_real_loss, mse_gan_fake_loss = self._score_loss(
                self.mse_loss, scores_fake, scores_real)
            return_dict['D_mse_gan_loss'] = mse_gan_loss
            return_dict['D_mse_gan_real_loss'] = mse_gan_real_loss
            return_dict['D_mse_gan_fake_loss'] = mse_gan_fake_loss
            loss = loss + self.mse_gan_loss_alpha * mse_gan_loss

        if self.use_hinge_gan_loss:
            hinge_gan_loss, hinge_gan_real_loss, hinge_gan_fake_loss = self._score_loss(
                self.hinge_loss, scores_fake, scores_real)
            return_dict['D_hinge_gan_loss'] = hinge_gan_loss
            return_dict['D_hinge_gan_real_loss'] = hinge_gan_real_loss
            return_dict['D_hinge_gan_fake_loss'] = hinge_gan_fake_loss
            loss = loss + self.hinge_gan_loss_alpha * hinge_gan_loss

        return_dict['D_loss'] = loss
        return return_dict
|
|
@ -0,0 +1,48 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
class ResidualStack(nn.Module):
    """Stack of dilated residual blocks with 1x1 convolution shortcuts.

    Block ``idx`` uses a dilation of ``kernel_size ** idx`` and reflection
    padding chosen so that the time dimension is preserved.

    Args:
        channels: number of input and output channels of every layer.
        num_res_blocks: number of residual blocks.
        kernel_size: kernel size of the dilated convolutions; must be odd.
    """
    def __init__(self, channels, num_res_blocks, kernel_size):
        super().__init__()

        assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd."
        base_padding = (kernel_size - 1) // 2

        self.blocks = nn.ModuleList()
        for idx in range(num_res_blocks):
            layer_dilation = kernel_size**idx
            layer_padding = base_padding * layer_dilation
            self.blocks += [nn.Sequential(
                nn.LeakyReLU(0.2),
                nn.ReflectionPad1d(layer_padding),
                weight_norm(
                    nn.Conv1d(channels,
                              channels,
                              kernel_size=kernel_size,
                              # The dilation, not the padding, belongs here.
                              # The two only coincide for kernel_size == 3;
                              # for other sizes the block output no longer
                              # matched the shortcut length.
                              dilation=layer_dilation,
                              bias=True)),
                nn.LeakyReLU(0.2),
                weight_norm(
                    nn.Conv1d(channels, channels, kernel_size=1, bias=True)),
            )]

        self.shortcuts = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size=1,
                                  bias=True)) for _ in range(num_res_blocks)
        ])

    def forward(self, x):
        """Apply every block as ``x = shortcut(x) + block(x)``."""
        for block, shortcut in zip(self.blocks, self.shortcuts):
            x = shortcut(x) + block(x)
        return x

    def remove_weight_norm(self):
        """Remove weight normalization from all convolutions of the stack."""
        for block, shortcut in zip(self.blocks, self.shortcuts):
            # indices 2 and 4 are the two convolutions inside each block
            nn.utils.remove_weight_norm(block[2])
            nn.utils.remove_weight_norm(block[4])
            nn.utils.remove_weight_norm(shortcut)
|
|
@ -0,0 +1,55 @@
|
|||
"""Pseudo QMF modules."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from scipy import signal as sig
|
||||
|
||||
|
||||
# adapted from
|
||||
# https://github.com/kan-bayashi/ParallelWaveGAN/tree/master/parallel_wavegan
|
||||
class PQMF(torch.nn.Module):
    """Pseudo-QMF filter bank built from a Kaiser-windowed ``firwin`` prototype.

    Args:
        N: number of subbands.
        taps: filter order; the filters have ``taps + 1`` coefficients.
        cutoff: cutoff passed to ``scipy.signal.firwin``.
        beta: beta of the Kaiser window.
    """
    def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0):
        super().__init__()

        self.N = N
        self.taps = taps
        self.cutoff = cutoff
        self.beta = beta

        # cosine-modulate the prototype low-pass filter into N band filters
        prototype = sig.firwin(taps + 1, cutoff, window=('kaiser', beta))
        positions = np.arange(taps + 1) - ((taps - 1) / 2)
        analysis_bank = np.zeros((N, len(prototype)))
        synthesis_bank = np.zeros((N, len(prototype)))
        for band in range(N):
            modulation = (2 * band + 1) * (np.pi / (2 * N)) * positions
            phase = (-1)**band * np.pi / 4
            analysis_bank[band] = 2 * prototype * np.cos(modulation + phase)
            synthesis_bank[band] = 2 * prototype * np.cos(modulation - phase)

        # analysis weights: (N, 1, taps + 1); synthesis weights: (1, N, taps + 1)
        self.register_buffer("H", torch.from_numpy(analysis_bank[:, None, :]).float())
        self.register_buffer("G", torch.from_numpy(synthesis_bank[None, :, :]).float())

        # per-band unit impulse used by the transposed convolution in `synthesis`
        updown_filter = torch.zeros((N, N, N)).float()
        band_ids = torch.arange(N)
        updown_filter[band_ids, band_ids, 0] = 1.0
        self.register_buffer("updown_filter", updown_filter)

        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def analysis(self, x):
        """Filter ``x`` with the analysis bank using a stride of ``N``."""
        half_taps = self.taps // 2
        return F.conv1d(x, self.H, padding=half_taps, stride=self.N)

    def synthesis(self, x):
        """Upsample the subband signals by ``N`` and merge them with the synthesis bank."""
        upsampled = F.conv_transpose1d(x,
                                       self.updown_filter * self.N,
                                       stride=self.N)
        return F.conv1d(upsampled, self.G, padding=self.taps // 2)
|
|
@ -0,0 +1,127 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Tomoki Hayashi
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
"""Pseudo QMF modules."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from scipy.signal import kaiser
|
||||
|
||||
|
||||
def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
    """Design prototype filter for PQMF.

    This method is based on `A Kaiser window approach for the design of prototype
    filters of cosine modulated filterbanks`_.

    Args:
        taps (int): The number of filter taps.
        cutoff_ratio (float): Cut-off frequency ratio.
        beta (float): Beta coefficient for kaiser window.

    Returns:
        ndarray: Impulse response of prototype filter (taps + 1,).

    Raises:
        AssertionError: If ``taps`` is odd or ``cutoff_ratio`` is not in (0, 1).

    .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
        https://ieeexplore.ieee.org/abstract/document/681427

    """
    # `scipy.signal.kaiser` is not available in newer SciPy releases; the
    # `scipy.signal.windows` namespace provides the same function.
    from scipy.signal.windows import kaiser

    # check the arguments are valid
    assert taps % 2 == 0, "The number of taps mush be even number."
    assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."

    # make initial filter: ideal low-pass impulse response centered at taps / 2
    omega_c = np.pi * cutoff_ratio
    centered = np.arange(taps + 1) - 0.5 * taps
    with np.errstate(invalid='ignore'):
        h_i = np.sin(omega_c * centered) / (np.pi * centered)
    h_i[taps // 2] = np.cos(0) * cutoff_ratio  # fix nan due to indeterminate form

    # apply kaiser window
    w = kaiser(taps + 1, beta)
    h = h_i * w

    return h
|
||||
|
||||
|
||||
class PQMF(torch.nn.Module):
    """Pseudo-QMF analysis / synthesis filter bank.

    This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.

    .. _`Near-perfect-reconstruction pseudo-QMF banks`:
        https://ieeexplore.ieee.org/document/258122

    """

    def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0):
        """Build the filter bank.

        Args:
            subbands (int): The number of subbands.
            taps (int): The number of filter taps.
            cutoff_ratio (float): Cut-off frequency ratio.
            beta (float): Beta coefficient for kaiser window.

        """
        super().__init__()

        # cosine-modulate the prototype filter into one filter per subband
        prototype = design_prototype_filter(taps, cutoff_ratio, beta)
        positions = np.arange(taps + 1) - ((taps - 1) / 2)
        bank_shape = (subbands, len(prototype))
        h_analysis = np.zeros(bank_shape)
        h_synthesis = np.zeros(bank_shape)
        for band in range(subbands):
            modulation = (2 * band + 1) * (np.pi / (2 * subbands)) * positions
            phase = (-1) ** band * np.pi / 4
            h_analysis[band] = 2 * prototype * np.cos(modulation + phase)
            h_synthesis[band] = 2 * prototype * np.cos(modulation - phase)

        # store the coefficients as buffers:
        # analysis (subbands, 1, taps + 1), synthesis (1, subbands, taps + 1)
        self.register_buffer("analysis_filter",
                             torch.from_numpy(h_analysis).float().unsqueeze(1))
        self.register_buffer("synthesis_filter",
                             torch.from_numpy(h_synthesis).float().unsqueeze(0))

        # filter for downsampling & upsampling: a unit impulse per subband
        updown_filter = torch.zeros((subbands, subbands, subbands)).float()
        band_ids = torch.arange(subbands)
        updown_filter[band_ids, band_ids, 0] = 1.0
        self.register_buffer("updown_filter", updown_filter)
        self.subbands = subbands

        # zero padding of taps // 2 samples on both sides, applied before filtering
        self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)

    def analysis(self, x):
        """Split a signal into subbands.

        Args:
            x (Tensor): Input tensor (B, 1, T).

        Returns:
            Tensor: Output tensor (B, subbands, T // subbands).

        """
        filtered = F.conv1d(self.pad_fn(x), self.analysis_filter)
        # keep every `subbands`-th sample of each band
        return F.conv1d(filtered, self.updown_filter, stride=self.subbands)

    def synthesis(self, x):
        """Merge subband signals back into a single signal.

        Args:
            x (Tensor): Input tensor (B, subbands, T // subbands).

        Returns:
            Tensor: Output tensor (B, 1, T).

        """
        # NOTE(kan-bayashi): Power will be decreased so here multiply by # subbands.
        #   Not sure this is the correct way, it is better to check again.
        # TODO(kan-bayashi): Understand the reconstruction procedure
        upsampled = F.conv_transpose1d(x,
                                       self.updown_filter * self.subbands,
                                       stride=self.subbands)
        return F.conv1d(self.pad_fn(upsampled), self.synthesis_filter)
|
|
@ -0,0 +1,640 @@
|
|||
0.0000000e+000
|
||||
-5.5252865e-004
|
||||
-5.6176926e-004
|
||||
-4.9475181e-004
|
||||
-4.8752280e-004
|
||||
-4.8937912e-004
|
||||
-5.0407143e-004
|
||||
-5.2265643e-004
|
||||
-5.4665656e-004
|
||||
-5.6778026e-004
|
||||
-5.8709305e-004
|
||||
-6.1327474e-004
|
||||
-6.3124935e-004
|
||||
-6.5403334e-004
|
||||
-6.7776908e-004
|
||||
-6.9416146e-004
|
||||
-7.1577365e-004
|
||||
-7.2550431e-004
|
||||
-7.4409419e-004
|
||||
-7.4905981e-004
|
||||
-7.6813719e-004
|
||||
-7.7248486e-004
|
||||
-7.8343323e-004
|
||||
-7.7798695e-004
|
||||
-7.8036647e-004
|
||||
-7.8014496e-004
|
||||
-7.7579773e-004
|
||||
-7.6307936e-004
|
||||
-7.5300014e-004
|
||||
-7.3193572e-004
|
||||
-7.2153920e-004
|
||||
-6.9179375e-004
|
||||
-6.6504151e-004
|
||||
-6.3415949e-004
|
||||
-5.9461189e-004
|
||||
-5.5645764e-004
|
||||
-5.1455722e-004
|
||||
-4.6063255e-004
|
||||
-4.0951215e-004
|
||||
-3.5011759e-004
|
||||
-2.8969812e-004
|
||||
-2.0983373e-004
|
||||
-1.4463809e-004
|
||||
-6.1733441e-005
|
||||
1.3494974e-005
|
||||
1.0943831e-004
|
||||
2.0430171e-004
|
||||
2.9495311e-004
|
||||
4.0265402e-004
|
||||
5.1073885e-004
|
||||
6.2393761e-004
|
||||
7.4580259e-004
|
||||
8.6084433e-004
|
||||
9.8859883e-004
|
||||
1.1250155e-003
|
||||
1.2577885e-003
|
||||
1.3902495e-003
|
||||
1.5443220e-003
|
||||
1.6868083e-003
|
||||
1.8348265e-003
|
||||
1.9841141e-003
|
||||
2.1461584e-003
|
||||
2.3017255e-003
|
||||
2.4625617e-003
|
||||
2.6201759e-003
|
||||
2.7870464e-003
|
||||
2.9469448e-003
|
||||
3.1125421e-003
|
||||
3.2739613e-003
|
||||
3.4418874e-003
|
||||
3.6008268e-003
|
||||
3.7603923e-003
|
||||
3.9207432e-003
|
||||
4.0819753e-003
|
||||
4.2264269e-003
|
||||
4.3730720e-003
|
||||
4.5209853e-003
|
||||
4.6606461e-003
|
||||
4.7932561e-003
|
||||
4.9137604e-003
|
||||
5.0393023e-003
|
||||
5.1407354e-003
|
||||
5.2461166e-003
|
||||
5.3471681e-003
|
||||
5.4196776e-003
|
||||
5.4876040e-003
|
||||
5.5475715e-003
|
||||
5.5938023e-003
|
||||
5.6220643e-003
|
||||
5.6455197e-003
|
||||
5.6389200e-003
|
||||
5.6266114e-003
|
||||
5.5917129e-003
|
||||
5.5404364e-003
|
||||
5.4753783e-003
|
||||
5.3838976e-003
|
||||
5.2715759e-003
|
||||
5.1382275e-003
|
||||
4.9839688e-003
|
||||
4.8109469e-003
|
||||
4.6039530e-003
|
||||
4.3801862e-003
|
||||
4.1251642e-003
|
||||
3.8456408e-003
|
||||
3.5401247e-003
|
||||
3.2091886e-003
|
||||
2.8446758e-003
|
||||
2.4508540e-003
|
||||
2.0274176e-003
|
||||
1.5784683e-003
|
||||
1.0902329e-003
|
||||
5.8322642e-004
|
||||
2.7604519e-005
|
||||
-5.4642809e-004
|
||||
-1.1568136e-003
|
||||
-1.8039473e-003
|
||||
-2.4826724e-003
|
||||
-3.1933778e-003
|
||||
-3.9401124e-003
|
||||
-4.7222596e-003
|
||||
-5.5337211e-003
|
||||
-6.3792293e-003
|
||||
-7.2615817e-003
|
||||
-8.1798233e-003
|
||||
-9.1325330e-003
|
||||
-1.0115022e-002
|
||||
-1.1131555e-002
|
||||
-1.2185000e-002
|
||||
-1.3271822e-002
|
||||
-1.4390467e-002
|
||||
-1.5540555e-002
|
||||
-1.6732471e-002
|
||||
-1.7943338e-002
|
||||
-1.9187243e-002
|
||||
-2.0453179e-002
|
||||
-2.1746755e-002
|
||||
-2.3068017e-002
|
||||
-2.4416099e-002
|
||||
-2.5787585e-002
|
||||
-2.7185943e-002
|
||||
-2.8607217e-002
|
||||
-3.0050266e-002
|
||||
-3.1501761e-002
|
||||
-3.2975408e-002
|
||||
-3.4462095e-002
|
||||
-3.5969756e-002
|
||||
-3.7481285e-002
|
||||
-3.9005368e-002
|
||||
-4.0534917e-002
|
||||
-4.2064909e-002
|
||||
-4.3609754e-002
|
||||
-4.5148841e-002
|
||||
-4.6684303e-002
|
||||
-4.8216572e-002
|
||||
-4.9738576e-002
|
||||
-5.1255616e-002
|
||||
-5.2763075e-002
|
||||
-5.4245277e-002
|
||||
-5.5717365e-002
|
||||
-5.7161645e-002
|
||||
-5.8591568e-002
|
||||
-5.9983748e-002
|
||||
-6.1345517e-002
|
||||
-6.2685781e-002
|
||||
-6.3971590e-002
|
||||
-6.5224711e-002
|
||||
-6.6436751e-002
|
||||
-6.7607599e-002
|
||||
-6.8704383e-002
|
||||
-6.9763024e-002
|
||||
-7.0762871e-002
|
||||
-7.1700267e-002
|
||||
-7.2568258e-002
|
||||
-7.3362026e-002
|
||||
-7.4100364e-002
|
||||
-7.4745256e-002
|
||||
-7.5313734e-002
|
||||
-7.5800836e-002
|
||||
-7.6199248e-002
|
||||
-7.6499217e-002
|
||||
-7.6709349e-002
|
||||
-7.6817398e-002
|
||||
-7.6823001e-002
|
||||
-7.6720492e-002
|
||||
-7.6505072e-002
|
||||
-7.6174832e-002
|
||||
-7.5730576e-002
|
||||
-7.5157626e-002
|
||||
-7.4466439e-002
|
||||
-7.3640601e-002
|
||||
-7.2677464e-002
|
||||
-7.1582636e-002
|
||||
-7.0353307e-002
|
||||
-6.8966401e-002
|
||||
-6.7452502e-002
|
||||
-6.5769067e-002
|
||||
-6.3944481e-002
|
||||
-6.1960278e-002
|
||||
-5.9816657e-002
|
||||
-5.7515269e-002
|
||||
-5.5046003e-002
|
||||
-5.2409382e-002
|
||||
-4.9597868e-002
|
||||
-4.6630331e-002
|
||||
-4.3476878e-002
|
||||
-4.0145828e-002
|
||||
-3.6641812e-002
|
||||
-3.2958393e-002
|
||||
-2.9082401e-002
|
||||
-2.5030756e-002
|
||||
-2.0799707e-002
|
||||
-1.6370126e-002
|
||||
-1.1762383e-002
|
||||
-6.9636862e-003
|
||||
-1.9765601e-003
|
||||
3.2086897e-003
|
||||
8.5711749e-003
|
||||
1.4128883e-002
|
||||
1.9883413e-002
|
||||
2.5822729e-002
|
||||
3.1953127e-002
|
||||
3.8277657e-002
|
||||
4.4780682e-002
|
||||
5.1480418e-002
|
||||
5.8370533e-002
|
||||
6.5440985e-002
|
||||
7.2694330e-002
|
||||
8.0137293e-002
|
||||
8.7754754e-002
|
||||
9.5553335e-002
|
||||
1.0353295e-001
|
||||
1.1168269e-001
|
||||
1.2000780e-001
|
||||
1.2850029e-001
|
||||
1.3715518e-001
|
||||
1.4597665e-001
|
||||
1.5496071e-001
|
||||
1.6409589e-001
|
||||
1.7338082e-001
|
||||
1.8281725e-001
|
||||
1.9239667e-001
|
||||
2.0212502e-001
|
||||
2.1197359e-001
|
||||
2.2196527e-001
|
||||
2.3206909e-001
|
||||
2.4230169e-001
|
||||
2.5264803e-001
|
||||
2.6310533e-001
|
||||
2.7366340e-001
|
||||
2.8432142e-001
|
||||
2.9507167e-001
|
||||
3.0590986e-001
|
||||
3.1682789e-001
|
||||
3.2781137e-001
|
||||
3.3887227e-001
|
||||
3.4999141e-001
|
||||
3.6115899e-001
|
||||
3.7237955e-001
|
||||
3.8363500e-001
|
||||
3.9492118e-001
|
||||
4.0623177e-001
|
||||
4.1756969e-001
|
||||
4.2891199e-001
|
||||
4.4025538e-001
|
||||
4.5159965e-001
|
||||
4.6293081e-001
|
||||
4.7424532e-001
|
||||
4.8552531e-001
|
||||
4.9677083e-001
|
||||
5.0798175e-001
|
||||
5.1912350e-001
|
||||
5.3022409e-001
|
||||
5.4125534e-001
|
||||
5.5220513e-001
|
||||
5.6307891e-001
|
||||
5.7385241e-001
|
||||
5.8454032e-001
|
||||
5.9511231e-001
|
||||
6.0557835e-001
|
||||
6.1591099e-001
|
||||
6.2612427e-001
|
||||
6.3619801e-001
|
||||
6.4612697e-001
|
||||
6.5590163e-001
|
||||
6.6551399e-001
|
||||
6.7496632e-001
|
||||
6.8423533e-001
|
||||
6.9332824e-001
|
||||
7.0223887e-001
|
||||
7.1094104e-001
|
||||
7.1944626e-001
|
||||
7.2774489e-001
|
||||
7.3582118e-001
|
||||
7.4368279e-001
|
||||
7.5131375e-001
|
||||
7.5870808e-001
|
||||
7.6586749e-001
|
||||
7.7277809e-001
|
||||
7.7942875e-001
|
||||
7.8583531e-001
|
||||
7.9197358e-001
|
||||
7.9784664e-001
|
||||
8.0344858e-001
|
||||
8.0876950e-001
|
||||
8.1381913e-001
|
||||
8.1857760e-001
|
||||
8.2304199e-001
|
||||
8.2722753e-001
|
||||
8.3110385e-001
|
||||
8.3469374e-001
|
||||
8.3797173e-001
|
||||
8.4095414e-001
|
||||
8.4362383e-001
|
||||
8.4598185e-001
|
||||
8.4803158e-001
|
||||
8.4978052e-001
|
||||
8.5119715e-001
|
||||
8.5230470e-001
|
||||
8.5310209e-001
|
||||
8.5357206e-001
|
||||
8.5373856e-001
|
||||
8.5357206e-001
|
||||
8.5310209e-001
|
||||
8.5230470e-001
|
||||
8.5119715e-001
|
||||
8.4978052e-001
|
||||
8.4803158e-001
|
||||
8.4598185e-001
|
||||
8.4362383e-001
|
||||
8.4095414e-001
|
||||
8.3797173e-001
|
||||
8.3469374e-001
|
||||
8.3110385e-001
|
||||
8.2722753e-001
|
||||
8.2304199e-001
|
||||
8.1857760e-001
|
||||
8.1381913e-001
|
||||
8.0876950e-001
|
||||
8.0344858e-001
|
||||
7.9784664e-001
|
||||
7.9197358e-001
|
||||
7.8583531e-001
|
||||
7.7942875e-001
|
||||
7.7277809e-001
|
||||
7.6586749e-001
|
||||
7.5870808e-001
|
||||
7.5131375e-001
|
||||
7.4368279e-001
|
||||
7.3582118e-001
|
||||
7.2774489e-001
|
||||
7.1944626e-001
|
||||
7.1094104e-001
|
||||
7.0223887e-001
|
||||
6.9332824e-001
|
||||
6.8423533e-001
|
||||
6.7496632e-001
|
||||
6.6551399e-001
|
||||
6.5590163e-001
|
||||
6.4612697e-001
|
||||
6.3619801e-001
|
||||
6.2612427e-001
|
||||
6.1591099e-001
|
||||
6.0557835e-001
|
||||
5.9511231e-001
|
||||
5.8454032e-001
|
||||
5.7385241e-001
|
||||
5.6307891e-001
|
||||
5.5220513e-001
|
||||
5.4125534e-001
|
||||
5.3022409e-001
|
||||
5.1912350e-001
|
||||
5.0798175e-001
|
||||
4.9677083e-001
|
||||
4.8552531e-001
|
||||
4.7424532e-001
|
||||
4.6293081e-001
|
||||
4.5159965e-001
|
||||
4.4025538e-001
|
||||
4.2891199e-001
|
||||
4.1756969e-001
|
||||
4.0623177e-001
|
||||
3.9492118e-001
|
||||
3.8363500e-001
|
||||
3.7237955e-001
|
||||
3.6115899e-001
|
||||
3.4999141e-001
|
||||
3.3887227e-001
|
||||
3.2781137e-001
|
||||
3.1682789e-001
|
||||
3.0590986e-001
|
||||
2.9507167e-001
|
||||
2.8432142e-001
|
||||
2.7366340e-001
|
||||
2.6310533e-001
|
||||
2.5264803e-001
|
||||
2.4230169e-001
|
||||
2.3206909e-001
|
||||
2.2196527e-001
|
||||
2.1197359e-001
|
||||
2.0212502e-001
|
||||
1.9239667e-001
|
||||
1.8281725e-001
|
||||
1.7338082e-001
|
||||
1.6409589e-001
|
||||
1.5496071e-001
|
||||
1.4597665e-001
|
||||
1.3715518e-001
|
||||
1.2850029e-001
|
||||
1.2000780e-001
|
||||
1.1168269e-001
|
||||
1.0353295e-001
|
||||
9.5553335e-002
|
||||
8.7754754e-002
|
||||
8.0137293e-002
|
||||
7.2694330e-002
|
||||
6.5440985e-002
|
||||
5.8370533e-002
|
||||
5.1480418e-002
|
||||
4.4780682e-002
|
||||
3.8277657e-002
|
||||
3.1953127e-002
|
||||
2.5822729e-002
|
||||
1.9883413e-002
|
||||
1.4128883e-002
|
||||
8.5711749e-003
|
||||
3.2086897e-003
|
||||
-1.9765601e-003
|
||||
-6.9636862e-003
|
||||
-1.1762383e-002
|
||||
-1.6370126e-002
|
||||
-2.0799707e-002
|
||||
-2.5030756e-002
|
||||
-2.9082401e-002
|
||||
-3.2958393e-002
|
||||
-3.6641812e-002
|
||||
-4.0145828e-002
|
||||
-4.3476878e-002
|
||||
-4.6630331e-002
|
||||
-4.9597868e-002
|
||||
-5.2409382e-002
|
||||
-5.5046003e-002
|
||||
-5.7515269e-002
|
||||
-5.9816657e-002
|
||||
-6.1960278e-002
|
||||
-6.3944481e-002
|
||||
-6.5769067e-002
|
||||
-6.7452502e-002
|
||||
-6.8966401e-002
|
||||
-7.0353307e-002
|
||||
-7.1582636e-002
|
||||
-7.2677464e-002
|
||||
-7.3640601e-002
|
||||
-7.4466439e-002
|
||||
-7.5157626e-002
|
||||
-7.5730576e-002
|
||||
-7.6174832e-002
|
||||
-7.6505072e-002
|
||||
-7.6720492e-002
|
||||
-7.6823001e-002
|
||||
-7.6817398e-002
|
||||
-7.6709349e-002
|
||||
-7.6499217e-002
|
||||
-7.6199248e-002
|
||||
-7.5800836e-002
|
||||
-7.5313734e-002
|
||||
-7.4745256e-002
|
||||
-7.4100364e-002
|
||||
-7.3362026e-002
|
||||
-7.2568258e-002
|
||||
-7.1700267e-002
|
||||
-7.0762871e-002
|
||||
-6.9763024e-002
|
||||
-6.8704383e-002
|
||||
-6.7607599e-002
|
||||
-6.6436751e-002
|
||||
-6.5224711e-002
|
||||
-6.3971590e-002
|
||||
-6.2685781e-002
|
||||
-6.1345517e-002
|
||||
-5.9983748e-002
|
||||
-5.8591568e-002
|
||||
-5.7161645e-002
|
||||
-5.5717365e-002
|
||||
-5.4245277e-002
|
||||
-5.2763075e-002
|
||||
-5.1255616e-002
|
||||
-4.9738576e-002
|
||||
-4.8216572e-002
|
||||
-4.6684303e-002
|
||||
-4.5148841e-002
|
||||
-4.3609754e-002
|
||||
-4.2064909e-002
|
||||
-4.0534917e-002
|
||||
-3.9005368e-002
|
||||
-3.7481285e-002
|
||||
-3.5969756e-002
|
||||
-3.4462095e-002
|
||||
-3.2975408e-002
|
||||
-3.1501761e-002
|
||||
-3.0050266e-002
|
||||
-2.8607217e-002
|
||||
-2.7185943e-002
|
||||
-2.5787585e-002
|
||||
-2.4416099e-002
|
||||
-2.3068017e-002
|
||||
-2.1746755e-002
|
||||
-2.0453179e-002
|
||||
-1.9187243e-002
|
||||
-1.7943338e-002
|
||||
-1.6732471e-002
|
||||
-1.5540555e-002
|
||||
-1.4390467e-002
|
||||
-1.3271822e-002
|
||||
-1.2185000e-002
|
||||
-1.1131555e-002
|
||||
-1.0115022e-002
|
||||
-9.1325330e-003
|
||||
-8.1798233e-003
|
||||
-7.2615817e-003
|
||||
-6.3792293e-003
|
||||
-5.5337211e-003
|
||||
-4.7222596e-003
|
||||
-3.9401124e-003
|
||||
-3.1933778e-003
|
||||
-2.4826724e-003
|
||||
-1.8039473e-003
|
||||
-1.1568136e-003
|
||||
-5.4642809e-004
|
||||
2.7604519e-005
|
||||
5.8322642e-004
|
||||
1.0902329e-003
|
||||
1.5784683e-003
|
||||
2.0274176e-003
|
||||
2.4508540e-003
|
||||
2.8446758e-003
|
||||
3.2091886e-003
|
||||
3.5401247e-003
|
||||
3.8456408e-003
|
||||
4.1251642e-003
|
||||
4.3801862e-003
|
||||
4.6039530e-003
|
||||
4.8109469e-003
|
||||
4.9839688e-003
|
||||
5.1382275e-003
|
||||
5.2715759e-003
|
||||
5.3838976e-003
|
||||
5.4753783e-003
|
||||
5.5404364e-003
|
||||
5.5917129e-003
|
||||
5.6266114e-003
|
||||
5.6389200e-003
|
||||
5.6455197e-003
|
||||
5.6220643e-003
|
||||
5.5938023e-003
|
||||
5.5475715e-003
|
||||
5.4876040e-003
|
||||
5.4196776e-003
|
||||
5.3471681e-003
|
||||
5.2461166e-003
|
||||
5.1407354e-003
|
||||
5.0393023e-003
|
||||
4.9137604e-003
|
||||
4.7932561e-003
|
||||
4.6606461e-003
|
||||
4.5209853e-003
|
||||
4.3730720e-003
|
||||
4.2264269e-003
|
||||
4.0819753e-003
|
||||
3.9207432e-003
|
||||
3.7603923e-003
|
||||
3.6008268e-003
|
||||
3.4418874e-003
|
||||
3.2739613e-003
|
||||
3.1125421e-003
|
||||
2.9469448e-003
|
||||
2.7870464e-003
|
||||
2.6201759e-003
|
||||
2.4625617e-003
|
||||
2.3017255e-003
|
||||
2.1461584e-003
|
||||
1.9841141e-003
|
||||
1.8348265e-003
|
||||
1.6868083e-003
|
||||
1.5443220e-003
|
||||
1.3902495e-003
|
||||
1.2577885e-003
|
||||
1.1250155e-003
|
||||
9.8859883e-004
|
||||
8.6084433e-004
|
||||
7.4580259e-004
|
||||
6.2393761e-004
|
||||
5.1073885e-004
|
||||
4.0265402e-004
|
||||
2.9495311e-004
|
||||
2.0430171e-004
|
||||
1.0943831e-004
|
||||
1.3494974e-005
|
||||
-6.1733441e-005
|
||||
-1.4463809e-004
|
||||
-2.0983373e-004
|
||||
-2.8969812e-004
|
||||
-3.5011759e-004
|
||||
-4.0951215e-004
|
||||
-4.6063255e-004
|
||||
-5.1455722e-004
|
||||
-5.5645764e-004
|
||||
-5.9461189e-004
|
||||
-6.3415949e-004
|
||||
-6.6504151e-004
|
||||
-6.9179375e-004
|
||||
-7.2153920e-004
|
||||
-7.3193572e-004
|
||||
-7.5300014e-004
|
||||
-7.6307936e-004
|
||||
-7.7579773e-004
|
||||
-7.8014496e-004
|
||||
-7.8036647e-004
|
||||
-7.7798695e-004
|
||||
-7.8343323e-004
|
||||
-7.7248486e-004
|
||||
-7.6813719e-004
|
||||
-7.4905981e-004
|
||||
-7.4409419e-004
|
||||
-7.2550431e-004
|
||||
-7.1577365e-004
|
||||
-6.9416146e-004
|
||||
-6.7776908e-004
|
||||
-6.5403334e-004
|
||||
-6.3124935e-004
|
||||
-6.1327474e-004
|
||||
-5.8709305e-004
|
||||
-5.6778026e-004
|
||||
-5.4665656e-004
|
||||
-5.2265643e-004
|
||||
-5.0407143e-004
|
||||
-4.8937912e-004
|
||||
-4.8752280e-004
|
||||
-4.9475181e-004
|
||||
-5.6176926e-004
|
||||
-5.5252865e-004
|
|
@ -0,0 +1,80 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
class MelganDiscriminator(nn.Module):
    """Single-scale MelGAN discriminator.

    The network is a list of layers applied one after another:

    1. an input block: reflection padding, a stride-1 convolution whose
       kernel size is the product of ``kernel_sizes``, and a LeakyReLU;
    2. one block per entry of ``downsample_factors``: a strided, grouped
       convolution followed by a LeakyReLU;
    3. an output block (convolution + LeakyReLU) and a final convolution
       without activation that produces the score map.

    Args:
        in_channels (int): number of channels of the input signal.
        out_channels (int): number of channels of the returned score map.
        kernel_sizes (tuple): kernel sizes of the two output convolutions.
        base_channels (int): number of channels after the input block.
        max_channels (int): upper bound for the width of the downsampling
            blocks.
        downsample_factors (tuple): stride of every downsampling block. It
            may be empty, in which case no downsampling block is created.
    """

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_sizes=(5, 3),
                 base_channels=16,
                 max_channels=1024,
                 downsample_factors=(4, 4, 4, 4)):
        super().__init__()
        self.layers = nn.ModuleList()

        # np.prod returns a numpy scalar; convert it so that only plain
        # Python ints are stored in the layer configuration.
        layer_kernel_size = int(np.prod(kernel_sizes))
        layer_padding = (layer_kernel_size - 1) // 2

        # initial layer
        self.layers.append(
            nn.Sequential(
                nn.ReflectionPad1d(layer_padding),
                weight_norm(
                    nn.Conv1d(in_channels,
                              base_channels,
                              layer_kernel_size,
                              stride=1)),
                nn.LeakyReLU(0.2, inplace=True)))

        # downsampling layers
        layer_in_channels = base_channels
        for downsample_factor in downsample_factors:
            layer_out_channels = min(layer_in_channels * downsample_factor,
                                     max_channels)
            layer_kernel_size = downsample_factor * 10 + 1
            layer_padding = (layer_kernel_size - 1) // 2
            layer_groups = layer_in_channels // 4
            self.layers.append(
                nn.Sequential(
                    weight_norm(
                        nn.Conv1d(layer_in_channels,
                                  layer_out_channels,
                                  kernel_size=layer_kernel_size,
                                  stride=downsample_factor,
                                  padding=layer_padding,
                                  groups=layer_groups)),
                    nn.LeakyReLU(0.2, inplace=True)))
            layer_in_channels = layer_out_channels

        # last 2 layers
        # ``layer_in_channels`` holds the width of the last block that was
        # built, and it is also defined when there is no downsampling block.
        layer_padding1 = (kernel_sizes[0] - 1) // 2
        layer_padding2 = (kernel_sizes[1] - 1) // 2
        self.layers += [
            nn.Sequential(
                weight_norm(
                    nn.Conv1d(layer_in_channels,
                              layer_in_channels,
                              kernel_size=kernel_sizes[0],
                              stride=1,
                              padding=layer_padding1)),
                nn.LeakyReLU(0.2, inplace=True),
            ),
            weight_norm(
                nn.Conv1d(layer_in_channels,
                          out_channels,
                          kernel_size=kernel_sizes[1],
                          stride=1,
                          padding=layer_padding2)),
        ]

    def forward(self, x):
        """Run all layers and collect every intermediate output.

        Args:
            x (Tensor): input of shape (B, in_channels, T).

        Returns:
            tuple: ``(score, feats)`` where ``score`` is the output of the
            last layer and ``feats`` is the list with the output of every
            layer in order (its last item is ``score`` itself).
        """
        feats = []
        for layer in self.layers:
            x = layer(x)
            feats.append(x)
        return x, feats
|
|
@ -0,0 +1,91 @@
|
|||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from TTS.vocoder.layers.melgan import ResidualStack
|
||||
|
||||
|
||||
class MelganGenerator(nn.Module):
    """MelGAN generator: turns conditioning features into a waveform.

    The model is one ``nn.Sequential`` made of an input projection, one
    upsampling stage per entry of ``upsample_factors`` (LeakyReLU, transposed
    convolution, ``ResidualStack``) and an output projection followed by
    ``tanh``. The channel count is halved at every upsampling stage.

    Args:
        in_channels (int): number of channels of the conditioning features.
        out_channels (int): number of output channels.
        proj_kernel (int): kernel size of the input and output projections.
            It must be odd.
        base_channels (int): number of channels after the input projection.
        upsample_factors (tuple): stride of every upsampling stage. It may
            be empty, in which case no upsampling stage is created.
        res_kernel (int): ``kernel_size`` passed to every ``ResidualStack``.
        num_res_blocks (int): ``num_res_blocks`` passed to every
            ``ResidualStack``.
    """

    def __init__(self,
                 in_channels=80,
                 out_channels=1,
                 proj_kernel=7,
                 base_channels=512,
                 upsample_factors=(8, 8, 2, 2),
                 res_kernel=3,
                 num_res_blocks=3):
        super().__init__()

        # assert model parameters
        assert (proj_kernel -
                1) % 2 == 0, " [!] proj_kernel should be an odd number."

        # setup additional model parameters
        base_padding = (proj_kernel - 1) // 2
        act_slope = 0.2
        # number of frames replicated on both sides by ``inference``
        self.inference_padding = 2

        # initial layer
        layers = [
            nn.ReflectionPad1d(base_padding),
            weight_norm(
                nn.Conv1d(in_channels,
                          base_channels,
                          kernel_size=proj_kernel,
                          stride=1,
                          bias=True)),
        ]

        # upsampling layers and residual stacks
        # Start from the projection width so that the final layer is still
        # well defined when there is no upsampling stage at all.
        layer_out_channels = base_channels
        for idx, upsample_factor in enumerate(upsample_factors):
            layer_in_channels = base_channels // (2**idx)
            layer_out_channels = base_channels // (2**(idx + 1))
            layer_filter_size = upsample_factor * 2
            layer_stride = upsample_factor
            layer_output_padding = upsample_factor % 2
            layer_padding = upsample_factor // 2 + layer_output_padding
            layers += [
                nn.LeakyReLU(act_slope),
                weight_norm(
                    nn.ConvTranspose1d(layer_in_channels,
                                       layer_out_channels,
                                       layer_filter_size,
                                       stride=layer_stride,
                                       padding=layer_padding,
                                       output_padding=layer_output_padding,
                                       bias=True)),
                ResidualStack(channels=layer_out_channels,
                              num_res_blocks=num_res_blocks,
                              kernel_size=res_kernel),
            ]

        layers += [nn.LeakyReLU(act_slope)]

        # final layer
        layers += [
            nn.ReflectionPad1d(base_padding),
            weight_norm(
                nn.Conv1d(layer_out_channels,
                          out_channels,
                          proj_kernel,
                          stride=1,
                          bias=True)),
            nn.Tanh(),
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, cond_features):
        """Generate a signal from ``cond_features`` (B, in_channels, T)."""
        return self.layers(cond_features)

    def inference(self, cond_features):
        """Generate a signal for inference.

        The features are moved to the device of the model and
        ``inference_padding`` frames are replicated at both ends before the
        layers are applied.

        Args:
            cond_features (Tensor): features of shape (B, in_channels, T).

        Returns:
            Tensor: output of the layers for the padded features.
        """
        # Read the device from a registered parameter. With
        # ``torch.nn.utils.weight_norm`` the ``weight`` attribute is a tensor
        # recomputed before each forward call, so its device can be stale
        # right after the module has been moved.
        device = next(self.parameters()).device
        cond_features = cond_features.to(device)
        cond_features = torch.nn.functional.pad(
            cond_features,
            (self.inference_padding, self.inference_padding),
            'replicate')
        return self.layers(cond_features)
|
|
@ -0,0 +1,41 @@
|
|||
from torch import nn
|
||||
|
||||
from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator
|
||||
|
||||
|
||||
class MelganMultiscaleDiscriminator(nn.Module):
    """Several ``MelganDiscriminator`` instances applied at different scales.

    The first discriminator sees the input as it is; before every following
    discriminator the signal is passed through an average pooling layer.

    Args:
        in_channels (int): forwarded to every ``MelganDiscriminator``.
        out_channels (int): forwarded to every ``MelganDiscriminator``.
        num_scales (int): number of discriminators.
        kernel_sizes (tuple): forwarded to every ``MelganDiscriminator``.
        base_channels (int): forwarded to every ``MelganDiscriminator``.
        max_channels (int): forwarded to every ``MelganDiscriminator``.
        downsample_factors (tuple): forwarded to every
            ``MelganDiscriminator``.
        pooling_kernel_size (int): kernel size of the pooling layer.
        pooling_stride (int): stride of the pooling layer.
        pooling_padding (int): padding of the pooling layer.
    """

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 num_scales=3,
                 kernel_sizes=(5, 3),
                 base_channels=16,
                 max_channels=1024,
                 downsample_factors=(4, 4, 4, 4),
                 pooling_kernel_size=4,
                 pooling_stride=2,
                 pooling_padding=1):
        super(MelganMultiscaleDiscriminator, self).__init__()

        # every scale uses an identically configured discriminator
        disc_kwargs = dict(in_channels=in_channels,
                           out_channels=out_channels,
                           kernel_sizes=kernel_sizes,
                           base_channels=base_channels,
                           max_channels=max_channels,
                           downsample_factors=downsample_factors)
        self.discriminators = nn.ModuleList()
        for _ in range(num_scales):
            self.discriminators.append(MelganDiscriminator(**disc_kwargs))

        # padded positions are left out of the average
        self.pooling = nn.AvgPool1d(kernel_size=pooling_kernel_size,
                                    stride=pooling_stride,
                                    padding=pooling_padding,
                                    count_include_pad=False)

    def forward(self, x):
        """Score ``x`` with every discriminator.

        Returns:
            tuple: ``(scores, feats)``, two lists with one entry per
            discriminator holding its score and its list of features.
        """
        scores, feats = [], []
        for scale_disc in self.discriminators:
            scale_score, scale_feats = scale_disc(x)
            scores.append(scale_score)
            feats.append(scale_feats)
            # the next discriminator works on the pooled signal
            x = self.pooling(x)
        return scores, feats
|
|
@ -0,0 +1,38 @@
|
|||
import torch
|
||||
|
||||
from TTS.vocoder.models.melgan_generator import MelganGenerator
|
||||
from TTS.vocoder.layers.pqmf import PQMF
|
||||
|
||||
|
||||
class MultibandMelganGenerator(MelganGenerator):
    """``MelganGenerator`` whose output channels are merged by a PQMF layer.

    Training code can use ``forward`` (inherited) to obtain the multi-channel
    output and ``pqmf_analysis`` / ``pqmf_synthesis`` to convert between the
    multi-channel and the single-channel representation. ``inference``
    returns the synthesized signal directly.

    Args:
        in_channels (int): forwarded to ``MelganGenerator``.
        out_channels (int): forwarded to ``MelganGenerator``.
        proj_kernel (int): forwarded to ``MelganGenerator``.
        base_channels (int): forwarded to ``MelganGenerator``.
        upsample_factors (tuple): forwarded to ``MelganGenerator``.
        res_kernel (int): forwarded to ``MelganGenerator``.
        num_res_blocks (int): forwarded to ``MelganGenerator``.
    """

    def __init__(self,
                 in_channels=80,
                 out_channels=4,
                 proj_kernel=7,
                 base_channels=384,
                 upsample_factors=(2, 8, 2, 2),
                 res_kernel=3,
                 num_res_blocks=3):
        super().__init__(in_channels=in_channels,
                         out_channels=out_channels,
                         proj_kernel=proj_kernel,
                         base_channels=base_channels,
                         upsample_factors=upsample_factors,
                         res_kernel=res_kernel,
                         num_res_blocks=num_res_blocks)
        # NOTE(review): the keyword names must match ``PQMF.__init__``, whose
        # signature is not visible from here -- confirm.
        self.pqmf_layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0)

    def pqmf_analysis(self, x):
        """Apply ``PQMF.analysis`` of the internal PQMF layer to ``x``."""
        return self.pqmf_layer.analysis(x)

    def pqmf_synthesis(self, x):
        """Apply ``PQMF.synthesis`` of the internal PQMF layer to ``x``."""
        return self.pqmf_layer.synthesis(x)

    def inference(self, cond_features):
        """Generate the multi-channel output and merge it with PQMF synthesis.

        The device transfer, the replicate padding and the layer stack are
        those of ``MelganGenerator.inference``; only the synthesis step is
        added here.
        """
        # The PQMF module is stored as ``pqmf_layer``; there is no ``pqmf``
        # attribute on this class.
        return self.pqmf_synthesis(super().inference(cond_features))
|
|
@ -0,0 +1,225 @@
|
|||
import numpy as np
|
||||
from torch import nn
|
||||
|
||||
|
||||
class GBlock(nn.Module):
    """Residual block that downsamples its input and adds condition features.

    The block doubles the number of channels and divides the time length by
    ``downsample_factor`` (average pooling).

    Args:
        in_channels (int): number of channels of the input.
        cond_channels (int): number of channels of the condition features.
        downsample_factor (int): kernel size and stride of the pooling.
    """

    def __init__(self, in_channels, cond_channels, downsample_factor):
        super(GBlock, self).__init__()

        hidden_channels = in_channels * 2
        self.in_channels = in_channels
        self.cond_channels = cond_channels
        self.downsample_factor = downsample_factor

        # main path, first half: pooling -> ReLU -> convolution
        self.start = nn.Sequential(
            nn.AvgPool1d(downsample_factor, stride=downsample_factor),
            nn.ReLU(),
            nn.Conv1d(in_channels, hidden_channels, kernel_size=3, padding=1),
        )
        # 1x1 projection of the condition features to the hidden width
        self.lc_conv1d = nn.Conv1d(cond_channels,
                                   hidden_channels,
                                   kernel_size=1)
        # main path, second half: ReLU -> dilated convolution
        self.end = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(hidden_channels,
                      hidden_channels,
                      kernel_size=3,
                      dilation=2,
                      padding=2),
        )
        # shortcut path: 1x1 convolution -> pooling
        self.residual = nn.Sequential(
            nn.Conv1d(in_channels, hidden_channels, kernel_size=1),
            nn.AvgPool1d(downsample_factor, stride=downsample_factor),
        )

    def forward(self, inputs, conditions):
        """Apply the block.

        Args:
            inputs (Tensor): tensor of shape (B, in_channels, T).
            conditions (Tensor): tensor of shape (B, cond_channels, T'),
                where T' is the length of ``inputs`` after pooling, because
                both are added together.

        Returns:
            Tensor: tensor of shape (B, 2 * in_channels, T').
        """
        conditioned = self.start(inputs) + self.lc_conv1d(conditions)
        return self.end(conditioned) + self.residual(inputs)
|
||||
|
||||
|
||||
class DBlock(nn.Module):
    """Residual block with optional downsampling by average pooling.

    Args:
        in_channels (int): number of channels of the input.
        out_channels (int): number of channels of the output.
        downsample_factor (int): kernel size and stride of the pooling.
            Pooling is only applied when the value is greater than 1.
    """

    def __init__(self, in_channels, out_channels, downsample_factor):
        super(DBlock, self).__init__()

        self.in_channels = in_channels
        self.downsample_factor = downsample_factor
        self.out_channels = out_channels

        pooling = nn.AvgPool1d(downsample_factor, stride=downsample_factor)
        self.donwsample_layer = pooling
        # main path: ReLU -> convolution -> ReLU -> dilated convolution
        self.layers = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(out_channels,
                      out_channels,
                      kernel_size=3,
                      dilation=2,
                      padding=2),
        )
        # shortcut path: 1x1 convolution
        self.residual = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=1),
        )

    def forward(self, inputs):
        """Apply the block to ``inputs`` of shape (B, in_channels, T).

        With downsampling the main path pools before its convolutions and
        the shortcut pools after its convolution.
        """
        if self.downsample_factor > 1:
            main = self.layers(self.donwsample_layer(inputs))
            shortcut = self.donwsample_layer(self.residual(inputs))
        else:
            main = self.layers(inputs)
            shortcut = self.residual(inputs)
        return main + shortcut
|
||||
|
||||
|
||||
class ConditionalDiscriminator(nn.Module):
    """Discriminator that scores a signal together with condition features.

    Layout: a ``DBlock`` with 64 output channels, one ``DBlock`` per entry of
    ``out_channels``, a ``GBlock`` that brings in the condition features, two
    more ``DBlock`` layers, global average pooling and a 1x1 convolution.

    Args:
        in_channels (int): number of channels the input is reshaped to.
        cond_channels (int): number of channels of the condition features.
        downsample_factors (tuple): one factor per entry of ``out_channels``
            plus a last one used by the ``GBlock``.
        out_channels (tuple): output channels of the ``DBlock`` layers that
            precede the ``GBlock``.
    """

    def __init__(self,
                 in_channels,
                 cond_channels,
                 downsample_factors=(2, 2, 2),
                 out_channels=(128, 256)):
        super(ConditionalDiscriminator, self).__init__()

        assert len(downsample_factors) == len(out_channels) + 1

        self.in_channels = in_channels
        self.cond_channels = cond_channels
        self.downsample_factors = downsample_factors
        self.out_channels = out_channels

        self.pre_cond_layers = nn.ModuleList()
        self.post_cond_layers = nn.ModuleList()

        # layers before condition features
        channels = 64
        self.pre_cond_layers.append(DBlock(in_channels, channels, 1))
        for stage, stage_channels in enumerate(out_channels):
            self.pre_cond_layers.append(
                DBlock(channels, stage_channels, downsample_factors[stage]))
            channels = stage_channels

        # condition block
        self.cond_block = GBlock(channels, cond_channels,
                                 downsample_factors[-1])

        # layers after condition block (the GBlock doubles the channels)
        doubled = channels * 2
        self.post_cond_layers.append(DBlock(doubled, doubled, 1))
        self.post_cond_layers.append(DBlock(doubled, doubled, 1))
        self.post_cond_layers.append(nn.AdaptiveAvgPool1d(1))
        self.post_cond_layers.append(nn.Conv1d(doubled, 1, kernel_size=1))

    def forward(self, inputs, conditions):
        """Score ``inputs`` given ``conditions``.

        ``inputs`` is first reshaped to (B, in_channels, -1). The result has
        shape (B, 1, 1) because of the global pooling and the final
        single-channel convolution.
        """
        hidden = inputs.view(inputs.size()[0], self.in_channels, -1)
        for pre_layer in self.pre_cond_layers:
            hidden = pre_layer(hidden)
        hidden = self.cond_block(hidden, conditions)
        for post_layer in self.post_cond_layers:
            hidden = post_layer(hidden)

        return hidden
|
||||
|
||||
|
||||
class UnconditionalDiscriminator(nn.Module):
    """Discriminator that scores a signal without condition features.

    Layout: a ``DBlock`` with ``base_channels`` outputs, one downsampling
    ``DBlock`` per entry of ``downsample_factors``, two more ``DBlock``
    layers, global average pooling and a 1x1 convolution.

    Args:
        in_channels (int): number of channels the input is reshaped to.
        base_channels (int): output channels of the first ``DBlock``.
        downsample_factors (tuple): factor of every downsampling ``DBlock``.
        out_channels (tuple): output channels of every downsampling
            ``DBlock``; entry ``i`` belongs to ``downsample_factors[i]``.
    """

    def __init__(self,
                 in_channels,
                 base_channels=64,
                 downsample_factors=(8, 4),
                 out_channels=(128, 256)):
        super().__init__()

        self.in_channels = in_channels
        self.downsample_factors = downsample_factors
        self.out_channels = out_channels

        self.layers = nn.ModuleList()
        self.layers.append(DBlock(self.in_channels, base_channels, 1))
        channels = base_channels
        for i, factor in enumerate(downsample_factors):
            self.layers.append(DBlock(channels, out_channels[i], factor))
            # Continue from the width this stage really produces; doubling
            # the previous width is only right when ``out_channels`` happens
            # to double at every stage (as the defaults do).
            channels = out_channels[i]
        self.layers += [
            DBlock(channels, channels, 1),
            DBlock(channels, channels, 1),
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(channels, 1, kernel_size=1),
        ]

    def forward(self, inputs):
        """Score ``inputs``.

        ``inputs`` is first reshaped to (B, in_channels, -1). The result has
        shape (B, 1, 1) because of the global pooling and the final
        single-channel convolution.
        """
        batch_size = inputs.size()[0]
        outputs = inputs.view(batch_size, self.in_channels, -1)
        for layer in self.layers:
            outputs = layer(outputs)
        return outputs
|
||||
|
||||
|
||||
class RandomWindowDiscriminator(nn.Module):
    """Random Window Discriminator as described in
    http://arxiv.org/abs/1909.11646

    For every window size one ``UnconditionalDiscriminator`` and one
    ``ConditionalDiscriminator`` are created. Each of them scores a randomly
    positioned window of the input.

    Args:
        cond_channels (int): number of channels of the condition features.
        hop_length (int): number of signal samples per condition frame.
        uncond_disc_donwsample_factors (tuple): ``downsample_factors`` of
            every unconditional discriminator.
        cond_disc_downsample_factors (tuple): per window size, the
            ``downsample_factors`` of the conditional discriminator.
        cond_disc_out_channels (tuple): per window size, the ``out_channels``
            of the conditional discriminator.
        window_sizes (tuple): window lengths in samples; each must be a
            multiple of ``hop_length``.
    """

    def __init__(self,
                 cond_channels,
                 hop_length,
                 uncond_disc_donwsample_factors=(8, 4),
                 cond_disc_downsample_factors=((8, 4, 2, 2, 2), (8, 4, 2, 2),
                                               (8, 4, 2), (8, 4), (4, 2, 2)),
                 cond_disc_out_channels=((128, 128, 256, 256), (128, 256, 256),
                                         (128, 256), (256, ), (128, 256)),
                 window_sizes=(512, 1024, 2048, 4096, 8192)):

        super().__init__()
        self.cond_channels = cond_channels
        self.window_sizes = window_sizes
        self.hop_length = hop_length
        self.base_window_size = self.hop_length * 2
        # input channels of the discriminators of every window size
        self.ks = [ws // self.base_window_size for ws in window_sizes]

        # check arguments
        assert len(cond_disc_downsample_factors) == len(
            cond_disc_out_channels) == len(window_sizes)
        for ws in window_sizes:
            assert ws % hop_length == 0

        for idx, cf in enumerate(cond_disc_downsample_factors):
            assert np.prod(cf) == hop_length // self.ks[idx]

        # define layers
        self.unconditional_discriminators = nn.ModuleList([])
        for k in self.ks:
            layer = UnconditionalDiscriminator(
                in_channels=k,
                base_channels=64,
                downsample_factors=uncond_disc_donwsample_factors)
            self.unconditional_discriminators.append(layer)

        self.conditional_discriminators = nn.ModuleList([])
        for idx, k in enumerate(self.ks):
            layer = ConditionalDiscriminator(
                in_channels=k,
                cond_channels=cond_channels,
                downsample_factors=cond_disc_downsample_factors[idx],
                out_channels=cond_disc_out_channels[idx])
            self.conditional_discriminators.append(layer)

    def forward(self, x, c):
        """Score random windows of ``x`` with and without conditioning.

        Args:
            x (Tensor): signal; windows are cut along its last dimension.
                Presumably of shape (B, 1, T) -- confirm against the caller.
            c (Tensor): condition features; frames are cut along its last
                dimension, one frame covering ``hop_length`` samples of ``x``.

        Returns:
            tuple: ``(scores, feats)``. ``scores`` holds the unconditional
            scores followed by the conditional scores, one per window size.
            ``feats`` is always an empty list.
        """
        scores = []
        feats = []
        # unconditional pass
        for (window_size, layer) in zip(self.window_sizes,
                                        self.unconditional_discriminators):
            # randint excludes its upper bound, hence the +1: the last valid
            # start is ``length - window_size`` and an input that is exactly
            # one window long must yield start 0 instead of an error.
            index = np.random.randint(x.shape[-1] - window_size + 1)

            score = layer(x[:, :, index:index + window_size])
            scores.append(score)

        # conditional pass
        for (window_size, layer) in zip(self.window_sizes,
                                        self.conditional_discriminators):
            frame_size = window_size // self.hop_length
            # same +1 as above, counted in condition frames
            lc_index = np.random.randint(c.shape[-1] - frame_size + 1)
            # the signal window is aligned with the selected frames
            sample_index = lc_index * self.hop_length
            x_sub = x[:, :,
                      sample_index:(lc_index + frame_size) * self.hop_length]
            c_sub = c[:, :, lc_index:lc_index + frame_size]

            score = layer(x_sub, c_sub)
            scores.append(score)
        return scores, feats
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,6 @@
|
|||
{
|
||||
"cells": [],
|
||||
"metadata": {},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
Binary file not shown.
|
@ -0,0 +1 @@
|
|||
|
|
@ -0,0 +1,24 @@
|
|||
{
|
||||
"audio":{
|
||||
"num_mels": 80, // size of the mel spec frame.
|
||||
"num_freq": 513, // number of stft frequency levels. Size of the linear spectrogram frame.
|
||||
"sample_rate": 22050, // wav sample-rate. If different than the original data, it is resampled.
|
||||
"frame_length_ms": null, // stft window length in ms.
|
||||
"frame_shift_ms": null, // stft window hop-length in ms.
|
||||
"hop_length": 256,
|
||||
"win_length": 1024,
|
||||
"preemphasis": 0.97, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
|
||||
"min_level_db": -100, // normalization range
|
||||
"ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
|
||||
"power": 1.5, // value to sharpen wav signals after GL algorithm.
|
||||
"griffin_lim_iters": 30,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
|
||||
"signal_norm": true, // normalize the spec values in range [0, 1]
|
||||
"symmetric_norm": true, // move normalization to range [-1, 1]
|
||||
"clip_norm": true, // clip normalized values into the range.
|
||||
"max_norm": 4, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
|
||||
"mel_fmin": 0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
|
||||
"mel_fmax": 8000, // maximum freq level for mel-spec. Tune for dataset!!
|
||||
"do_trim_silence": false
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
import os
|
||||
import numpy as np
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from TTS.vocoder.datasets.gan_dataset import GANDataset
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_config
|
||||
|
||||
|
||||
# Directory holding this test module; every other path is resolved against it.
file_path = os.path.dirname(os.path.realpath(__file__))

# Test configuration shared by all dataset cases.
C = load_config(os.path.join(file_path, 'test_config.json'))

# Folder for artifacts produced by the loader tests (created on import).
OUTPATH = os.path.join(file_path, "../../tests/outputs/loader_tests/")
os.makedirs(OUTPATH, exist_ok=True)

# Location of the LJSpeech test samples and a flag telling whether they exist.
test_data_path = os.path.join(file_path, "../../tests/data/ljspeech/")
ok_ljspeech = os.path.exists(test_data_path)
|
||||
|
||||
|
||||
|
||||
def gan_dataset_case(batch_size, seq_len, hop_len, conv_pad, return_segments, use_noise_augment, use_cache, num_workers):
    ''' Build a GANDataset + DataLoader with the given settings and check the batches. '''
    ap = AudioProcessor(**C.audio)
    _, train_items = load_wav_data(test_data_path, 10)
    dataset = GANDataset(ap,
                         train_items,
                         seq_len=seq_len,
                         hop_len=hop_len,
                         pad_short=2000,
                         conv_pad=conv_pad,
                         return_segments=return_segments,
                         use_noise_augment=use_noise_augment,
                         use_cache=use_cache)
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=num_workers,
                        pin_memory=True,
                        drop_last=True)

    if not return_segments:
        # whole-audio mode: only the first `max_iter` batches are inspected
        max_iter = 10
        for step, (feat, wav) in enumerate(loader, start=1):
            expected_feat_shape = (batch_size, ap.num_mels, (wav.shape[-1] // hop_len) + (conv_pad * 2))
            assert np.all(feat.shape == expected_feat_shape), f" [!] {feat.shape} vs {expected_feat_shape}"
            assert (feat.shape[2] - conv_pad * 2) * hop_len == wav.shape[2]
            if step == max_iter:
                break
        return

    # random-segment mode: every batch yields one pair for G and one for D;
    # the whole loader is consumed (no iteration cap is applied here).
    for item1, item2 in loader:
        feat1, wav1 = item1
        _feat2, _wav2 = item2  # second pair is unpacked but not checked further
        expected_feat_shape = (batch_size, ap.num_mels, seq_len // hop_len + conv_pad * 2)

        # check shapes
        assert np.all(feat1.shape == expected_feat_shape), f" [!] {feat1.shape} vs {expected_feat_shape}"
        assert (feat1.shape[2] - conv_pad * 2) * hop_len == wav1.shape[2]

        # feature vs audio match can only be verified when no noise was added
        if use_noise_augment:
            continue
        for idx in range(batch_size):
            audio = wav1[idx].squeeze()
            feat = feat1[idx]
            mel = ap.melspectrogram(audio)
            # the first 2 and the last frame are skipped due to the padding
            # applied in spec. computation.
            residual = (feat - mel[:, :feat1.shape[-1]])[:, 2:-1].sum()
            assert residual == 0, f' [!] {residual}'
|
||||
|
||||
|
||||
def test_parametrized_gan_dataset():
    ''' Run gan_dataset_case over a fixed grid of loader settings. '''
    hop = C.audio['hop_length']
    long_seq = hop * 10
    # columns: batch_size, seq_len, hop_len, conv_pad, return_segments,
    #          use_noise_augment, use_cache, num_workers
    params = [
        [32, long_seq, hop, 0, True, False, True, 0],
        [32, long_seq, hop, 0, True, False, True, 4],
        [1, long_seq, hop, 0, True, True, True, 0],
        [1, hop, hop, 0, True, True, True, 0],
        [1, long_seq, hop, 2, True, True, True, 0],
        [1, long_seq, hop, 0, False, True, True, 0],
        [1, long_seq, hop, 0, True, False, True, 0],
        [1, long_seq, hop, 0, True, True, False, 0],
        [1, long_seq, hop, 0, False, False, False, 0],
    ]
    for case in params:
        print(case)
        gan_dataset_case(*case)
|
|
@ -0,0 +1,62 @@
|
|||
import os
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
from TTS.vocoder.layers.losses import TorchSTFT, STFTLoss, MultiScaleSTFTLoss
|
||||
|
||||
from TTS.tests import get_tests_path, get_tests_input_path, get_tests_output_path
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_config
|
||||
|
||||
# Shared test locations.
TESTS_PATH = get_tests_path()
WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")

# Output folder for audio test artifacts (created on import).
OUT_PATH = os.path.join(get_tests_output_path(), "audio_tests")
os.makedirs(OUT_PATH, exist_ok=True)

# Audio processor configured from the test config that sits next to this module.
file_path = os.path.dirname(os.path.realpath(__file__))
C = load_config(os.path.join(file_path, 'test_config.json'))
ap = AudioProcessor(**C.audio)
|
||||
|
||||
|
||||
def test_torch_stft():
    """Check that TorchSTFT magnitudes match the AudioProcessor (librosa) STFT."""
    torch_stft = TorchSTFT(ap.n_fft, ap.hop_length, ap.win_length)
    # librosa stft
    wav = ap.load_wav(WAV_FILE)
    M_librosa = abs(ap._stft(wav))
    # torch stft
    wav = torch.from_numpy(wav[None, :]).float()
    M_torch = torch_stft(wav)
    # check the difference b/w librosa and torch outputs; the absolute value is
    # required, otherwise arbitrarily large negative deviations would pass.
    max_abs_diff = abs(M_librosa - M_torch[0].data.numpy()).max()
    assert max_abs_diff < 1e-5, f" [!] max abs difference: {max_abs_diff}"
|
||||
|
||||
|
||||
def test_stft_loss():
    """STFTLoss is zero for identical signals and positive against random noise."""
    criterion = STFTLoss(ap.n_fft, ap.hop_length, ap.win_length)
    signal = torch.from_numpy(ap.load_wav(WAV_FILE)[None, :]).float()

    # identical prediction and target
    mag_loss, sc_loss = criterion(signal, signal)
    assert mag_loss + sc_loss == 0

    # random target
    mag_loss, sc_loss = criterion(signal, torch.rand_like(signal))
    assert sc_loss < 1.0
    assert mag_loss + sc_loss > 0
|
||||
|
||||
|
||||
def test_multiscale_stft_loss():
    """MultiScaleSTFTLoss is zero for identical signals and positive against noise."""
    # three resolutions: half, nominal and double the configured STFT sizes
    n_ffts = [ap.n_fft // 2, ap.n_fft, ap.n_fft * 2]
    hop_lengths = [ap.hop_length // 2, ap.hop_length, ap.hop_length * 2]
    win_lengths = [ap.win_length // 2, ap.win_length, ap.win_length * 2]
    criterion = MultiScaleSTFTLoss(n_ffts, hop_lengths, win_lengths)

    signal = torch.from_numpy(ap.load_wav(WAV_FILE)[None, :]).float()

    # identical prediction and target
    mag_loss, sc_loss = criterion(signal, signal)
    assert mag_loss + sc_loss == 0

    # random target
    mag_loss, sc_loss = criterion(signal, torch.rand_like(signal))
    assert sc_loss < 1.0
    assert mag_loss + sc_loss > 0
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator
|
||||
from TTS.vocoder.models.melgan_multiscale_discriminator import MelganMultiscaleDiscriminator
|
||||
|
||||
|
||||
def test_melgan_discriminator():
    """A default MelganDiscriminator maps (4, 1, 2560) input to (4, 1, 10) scores."""
    batch, hop, frames = 4, 256, 10
    model = MelganDiscriminator()
    print(model)
    dummy_input = torch.rand((batch, 1, hop * frames))
    output, _ = model(dummy_input)
    assert np.all(output.shape == (batch, 1, frames))
|
||||
|
||||
|
||||
def test_melgan_multi_scale_discriminator():
    """Check score/feature counts and shapes of the default multi-scale discriminator."""
    model = MelganMultiscaleDiscriminator()
    print(model)
    scores, feats = model(torch.rand((4, 1, 256 * 16)))

    # one score tensor and one feature list per scale
    assert len(scores) == 3
    assert len(scores) == len(feats)
    assert np.all(scores[0].shape == (4, 1, 16))

    # shapes of the first three feature maps of the first scale
    expected_shapes = [(4, 16, 4096), (4, 64, 1024), (4, 256, 256)]
    for layer_idx, expected in enumerate(expected_shapes):
        assert np.all(feats[0][layer_idx].shape == expected)
|
|
@ -0,0 +1,15 @@
|
|||
import numpy as np
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
from TTS.vocoder.models.melgan_generator import MelganGenerator
|
||||
|
||||
def test_melgan_generator():
    """Check output lengths of the default MelganGenerator in forward and inference mode."""
    batch, num_mels, frames, hop = 4, 80, 64, 256
    model = MelganGenerator()
    print(model)
    dummy_input = torch.rand((batch, num_mels, frames))

    # training-time forward pass: one hop of samples per input frame
    output = model(dummy_input)
    assert np.all(output.shape == (batch, 1, frames * hop))

    # inference output is 4 frames longer than the input
    output = model.inference(dummy_input)
    assert np.all(output.shape == (batch, 1, (frames + 4) * hop))
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
import os
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from librosa.core import load
|
||||
|
||||
from TTS.tests import get_tests_path, get_tests_input_path, get_tests_output_path
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.vocoder.layers.pqmf import PQMF
|
||||
from TTS.vocoder.layers.pqmf2 import PQMF as PQMF2
|
||||
|
||||
|
||||
# Root folder of the test suite.
TESTS_PATH = get_tests_path()
# Sample recording fed through the PQMF analysis/synthesis round trip.
WAV_FILE = os.path.join(get_tests_input_path(),
                        "example_1.wav")
|
||||
|
||||
|
||||
def test_pqmf():
    """Run a wav through PQMF analysis + synthesis and store the reconstruction.

    This is a smoke test: it prints basic statistics of the reconstructed
    signal and writes it to the tests output folder for manual listening.
    """
    w, sr = load(WAV_FILE)

    layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0)
    # add batch and channel dimensions
    w2 = torch.from_numpy(w[None, None, :])
    b2 = layer.analysis(w2)
    w2_ = layer.synthesis(b2)

    print(w2_.max())
    print(w2_.min())
    print(w2_.mean())

    # keep test artifacts out of the current working directory
    out_dir = get_tests_output_path()
    os.makedirs(out_dir, exist_ok=True)
    sf.write(os.path.join(out_dir, 'pqmf_output.wav'),
             w2_.flatten().detach().numpy(), sr)
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
from TTS.vocoder.models.random_window_discriminator import RandomWindowDiscriminator
|
||||
|
||||
|
||||
def test_rwd():
    """RandomWindowDiscriminator returns 10 score tensors of shape (4, 1, 1)."""
    hop_length = 256
    window_sizes = (512, 1024, 2048, 4096, 8192)
    # one tuple of downsampling factors per window size
    downsample_factors = [
        (8, 4, 2, 2, 2),
        (8, 4, 2, 2),
        (8, 4, 2),
        (8, 4),
        (4, 2, 2),
    ]
    layer = RandomWindowDiscriminator(
        cond_channels=80,
        window_sizes=window_sizes,
        cond_disc_downsample_factors=downsample_factors,
        hop_length=hop_length)

    # waveform batch and conditioning features
    x = torch.rand([4, 1, 22050])
    c = torch.rand([4, 80, 22050 // 256])

    scores, _ = layer(x, c)
    assert len(scores) == 10
    assert np.all(scores[0].shape == (4, 1, 1))
|
|
@ -0,0 +1,585 @@
|
|||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from inspect import signature
|
||||
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
create_experiment_folder, get_git_branch,
|
||||
remove_experiment_folder, set_init_dict)
|
||||
from TTS.utils.io import copy_config_file, load_config
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
from TTS.utils.training import NoamLR
|
||||
from TTS.vocoder.datasets.gan_dataset import GANDataset
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data
|
||||
# from distribute import (DistributedSampler, apply_gradient_allreduce,
|
||||
# init_distributed, reduce_tensor)
|
||||
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
|
||||
from TTS.vocoder.utils.io import save_checkpoint, save_best_model
|
||||
from TTS.vocoder.utils.console_logger import ConsoleLogger
|
||||
from TTS.vocoder.utils.generic_utils import (check_config, plot_results,
|
||||
setup_discriminator,
|
||||
setup_generator)
|
||||
|
||||
# Fixed seed for torch's random number generator.
torch.manual_seed(54321)

# cuDNN is enabled with benchmark mode switched on.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True

# Device discovery; both values are read as globals by the training functions.
num_gpus = torch.cuda.device_count()
use_cuda = torch.cuda.is_available()
print(" > Using CUDA: ", use_cuda)
print(" > Number of GPUs: ", num_gpus)
|
||||
|
||||
|
||||
def setup_loader(ap, is_val=False, verbose=False):
    """Create the training or evaluation DataLoader from the global config `c`.

    Args:
        ap: audio processor handed to GANDataset.
        is_val: build the evaluation loader (batch size 1, whole utterances)
            instead of the training loader (random segments).
        verbose: forwarded to GANDataset.

    Returns:
        A DataLoader, or None when an evaluation loader is requested but
        `c.run_eval` is disabled.
    """
    if is_val and not c.run_eval:
        return None

    items = eval_data if is_val else train_data
    dataset = GANDataset(ap=ap,
                         items=items,
                         seq_len=c.seq_len,
                         hop_len=ap.hop_length,
                         pad_short=c.pad_short,
                         conv_pad=c.conv_pad,
                         is_training=not is_val,
                         return_segments=not is_val,
                         use_noise_augment=c.use_noise_augment,
                         use_cache=c.use_cache,
                         verbose=verbose)

    if is_val:
        batch_size, num_workers = 1, c.num_val_loader_workers
    else:
        batch_size, num_workers = c.batch_size, c.num_loader_workers

    # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=False,
                      drop_last=False,
                      sampler=None,
                      num_workers=num_workers,
                      pin_memory=False)
|
||||
|
||||
|
||||
def format_data(data):
    """Unpack a loader batch and move its tensors to the GPU when CUDA is used.

    Args:
        data: either a pair of `[features, audio]` lists (one for the
            generator, one for the discriminator) or a single
            `(features, audio)` pair.

    Returns:
        `(c_G, x_G, c_D, x_D)` for the paired layout, `(c, x)` otherwise.
    """
    def _dispatch(tensor):
        # tensors stay where they are when CUDA is not in use
        return tensor.cuda(non_blocking=True) if use_cuda else tensor

    if isinstance(data[0], list):
        # separate samples for the generator and the discriminator
        c_G, x_G = data[0]
        c_D, x_D = data[1]
        return _dispatch(c_G), _dispatch(x_G), _dispatch(c_D), _dispatch(x_D)

    # a single pair of features and audio
    c, x = data
    return _dispatch(c), _dispatch(x)
|
||||
|
||||
|
||||
def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
          scheduler_G, scheduler_D, ap, global_step, epoch):
    """Run one training epoch of the GAN vocoder.

    Every iteration performs a generator update; once `global_step` exceeds
    `c.steps_to_start_discriminator` the discriminator is also used for the
    generator loss and gets its own update.

    Args:
        model_G, model_D: generator and discriminator modules.
        criterion_G, criterion_D: loss callables returning dicts that contain
            'G_loss' and 'D_loss' respectively.
        optimizer_G, optimizer_D: optimizers of the two models.
        scheduler_G, scheduler_D: LR schedulers, stepped only if
            `c.noam_schedule` is set.
        ap: audio processor used for data loading and plotting.
        global_step: step counter at the start of the epoch.
        epoch: index of the current epoch.

    Returns:
        Tuple of (average training values, updated global_step).
    """
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model_G.train()
    model_D.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)

    # A discriminator whose forward() takes two arguments expects conditioning
    # features. The bound forward method must be inspected: the signature of
    # the module object itself is that of Module.__call__ (*args, **kwargs)
    # and therefore always has two parameters.
    disc_takes_cond = len(signature(model_D.forward).parameters) == 2

    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        c_G, y_G, c_D, y_D = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        # get current learning rates
        current_lr_G = list(optimizer_G.param_groups)[0]['lr']
        current_lr_D = list(optimizer_D.param_groups)[0]['lr']

        ##############################
        # GENERATOR
        ##############################

        # generator pass
        optimizer_G.zero_grad()
        y_hat = model_G(c_G)

        # discriminator inputs: ground truth is "real", generator output "fake"
        in_real_D = y_G
        in_fake_D = y_hat

        # PQMF formatting: a multi-channel generator output is treated as
        # sub-bands; D sees the synthesized signal while the generator
        # criterion compares the sub-bands.
        if y_hat.shape[1] > 1:
            in_fake_D = model_G.pqmf_synthesis(y_hat)
            y_G = model_G.pqmf_analysis(y_G)
            y_hat = y_hat.view(-1, 1, y_hat.shape[2])
            y_G = y_G.view(-1, 1, y_G.shape[2])

        # defaults used when D is not active yet or returns no features
        scores_fake, feats_fake, feats_real = None, None, None
        if global_step > c.steps_to_start_discriminator:

            # run D with or without cond. features
            if disc_takes_cond:
                D_out_fake = model_D(in_fake_D, c_G)
            else:
                D_out_fake = model_D(in_fake_D)
            D_out_real = None

            if c.use_feat_match_loss:
                with torch.no_grad():
                    if disc_takes_cond:
                        D_out_real = model_D(in_real_D, c_G)
                    else:
                        D_out_real = model_D(in_real_D)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is not None:
                    _, feats_real = D_out_real
            else:
                scores_fake = D_out_fake

        # compute losses
        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
                                  feats_real)
        loss_G = loss_G_dict['G_loss']

        # optimizer generator
        loss_G.backward()
        if c.gen_clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(model_G.parameters(),
                                           c.gen_clip_grad)
        optimizer_G.step()

        # setup lr
        if c.noam_schedule:
            scheduler_G.step()

        loss_dict = dict()
        for key, value in loss_G_dict.items():
            loss_dict[key] = value.item()

        ##############################
        # DISCRIMINATOR
        ##############################
        if global_step > c.steps_to_start_discriminator:
            # discriminator pass: fresh generator output without gradients
            with torch.no_grad():
                y_hat = model_G(c_D)

            # PQMF formatting
            if y_hat.shape[1] > 1:
                y_hat = model_G.pqmf_synthesis(y_hat)

            optimizer_D.zero_grad()

            # run D with or without cond. features
            if disc_takes_cond:
                D_out_fake = model_D(y_hat.detach(), c_D)
                D_out_real = model_D(y_D, c_D)
            else:
                D_out_fake = model_D(y_hat.detach())
                D_out_real = model_D(y_D)

            # format D outputs: only the scores are needed for the D loss
            if isinstance(D_out_fake, tuple):
                scores_fake, _ = D_out_fake
                scores_real, _ = D_out_real
            else:
                scores_fake = D_out_fake
                scores_real = D_out_real

            # compute losses
            loss_D_dict = criterion_D(scores_fake, scores_real)
            loss_D = loss_D_dict['D_loss']

            # optimizer discriminator
            loss_D.backward()
            if c.disc_clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(model_D.parameters(),
                                               c.disc_clip_grad)
            optimizer_D.step()

            # setup lr
            if c.noam_schedule:
                scheduler_D.step()

            for key, value in loss_D_dict.items():
                loss_dict[key] = value.item()

        step_time = time.time() - start_time
        epoch_time += step_time

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values['avg_' + key] = value
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      step_time, loader_time, current_lr_G,
                                      loss_dict, keep_avg.avg_values)

        # plot step stats
        if global_step % 10 == 0:
            iter_stats = {
                "lr_G": current_lr_G,
                "lr_D": current_lr_D,
                "step_time": step_time
            }
            iter_stats.update(loss_dict)
            tb_logger.tb_train_iter_stats(global_step, iter_stats)

        # save checkpoint
        if global_step % c.save_step == 0:
            if c.checkpoint:
                # save model
                save_checkpoint(model_G,
                                optimizer_G,
                                model_D,
                                optimizer_D,
                                global_step,
                                epoch,
                                OUT_PATH,
                                model_losses=loss_dict)

            # compute spectrograms
            figures = plot_results(in_fake_D, in_real_D, ap, global_step,
                                   'train')
            tb_logger.tb_train_figures(global_step, figures)

            # Sample audio
            sample_voice = in_fake_D[0].squeeze(0).detach().cpu().numpy()
            tb_logger.tb_train_audios(global_step,
                                      {'train/audio': sample_voice},
                                      c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Training Epoch Stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    # TODO: plot model stats
    # if c.tb_model_param_stats:
    #     tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step
|
||||
|
||||
|
||||
@torch.no_grad()
def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
    """Evaluate the generator on the validation loader and log the results.

    Args:
        model_G, model_D: generator and discriminator modules (set to eval mode).
        criterion_G: generator loss callable returning a dict of loss tensors.
        ap: audio processor used for data loading and plotting.
        global_step: current training step; advanced locally per eval batch
            and used to tag the logged figures and audio.
        epoch: index of the current epoch.

    Returns:
        Average evaluation values collected over the loader.
    """
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model_G.eval()
    model_D.eval()
    epoch_time = 0
    keep_avg = KeepAverage()

    # inspect forward(), not the module: Module.__call__ is (*args, **kwargs)
    disc_takes_cond = len(signature(model_D.forward).parameters) == 2

    end_time = time.time()
    c_logger.print_eval_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        c_G, y_G = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        ##############################
        # GENERATOR
        ##############################

        # generator pass
        y_hat = model_G(c_G)

        # discriminator inputs: ground truth is "real", generator output "fake"
        in_real_D = y_G
        in_fake_D = y_hat

        # PQMF formatting
        if y_hat.shape[1] > 1:
            in_fake_D = model_G.pqmf_synthesis(y_hat)
            y_G = model_G.pqmf_analysis(y_G)
            y_hat = y_hat.view(-1, 1, y_hat.shape[2])
            y_G = y_G.view(-1, 1, y_G.shape[2])

        # run D with or without cond. features
        if disc_takes_cond:
            D_out_fake = model_D(in_fake_D, c_G)
        else:
            D_out_fake = model_D(in_fake_D)
        D_out_real = None
        if c.use_feat_match_loss:
            if disc_takes_cond:
                D_out_real = model_D(in_real_D, c_G)
            else:
                D_out_real = model_D(in_real_D)

        # format D outputs; features stay None when D returns scores only
        feats_fake, feats_real = None, None
        if isinstance(D_out_fake, tuple):
            scores_fake, feats_fake = D_out_fake
            if D_out_real is not None:
                _, feats_real = D_out_real
        else:
            scores_fake = D_out_fake

        # compute losses
        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
                                  feats_real)

        loss_dict = dict()
        for key, value in loss_G_dict.items():
            loss_dict[key] = value.item()

        step_time = time.time() - start_time
        epoch_time += step_time

        # update avg stats
        update_eval_values = dict()
        for key, value in loss_dict.items():
            update_eval_values['avg_' + key] = value
        update_eval_values['avg_loader_time'] = loader_time
        update_eval_values['avg_step_time'] = step_time
        keep_avg.update_values(update_eval_values)

        # print eval stats
        if c.print_eval:
            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

        # restart the loader timer so loader_time covers one batch only
        end_time = time.time()

    # compute spectrograms of the last batch (full-band signals, as in train)
    figures = plot_results(in_fake_D, in_real_D, ap, global_step, 'eval')
    tb_logger.tb_eval_figures(global_step, figures)

    # Sample audio
    sample_voice = in_fake_D[0].squeeze(0).detach().cpu().numpy()
    tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
                             c.audio["sample_rate"])

    return keep_avg.avg_values
|
||||
|
||||
|
||||
# FIXME: move args definition/parsing inside of main?
|
||||
def main(args):  # pylint: disable=redefined-outer-name
    """Set up data, models, optimizers and losses, then run the training loop.

    Args:
        args: parsed command-line arguments. `args.restore_path` selects a
            checkpoint to initialize from; `args.restore_step` is set here.
    """
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio)
    # DISTRIBUTED
    # if num_gpus > 1:
    #     init_distributed(args.rank, num_gpus, args.group_id,
    #                      c.distributed["backend"], c.distributed["url"])

    # setup models
    model_gen = setup_generator(c)
    model_disc = setup_discriminator(c)

    # setup optimizers
    optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0)
    optimizer_disc = RAdam(model_disc.parameters(),
                           lr=c.lr_disc,
                           weight_decay=0)

    # setup criterion
    criterion_gen = GeneratorLoss(c)
    criterion_disc = DiscriminatorLoss(c)

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            model_gen.load_state_dict(checkpoint['model'])
            optimizer_gen.load_state_dict(checkpoint['optimizer'])
            model_disc.load_state_dict(checkpoint['model_disc'])
            optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
        except Exception:  # pylint: disable=broad-except
            # Best-effort fallback: load only what set_init_dict selects from
            # the checkpoint. KeyboardInterrupt/SystemExit are no longer
            # swallowed here.
            print(" > Partial model initialization.")
            model_dict = model_gen.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model_gen.load_state_dict(model_dict)

            model_dict = model_disc.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c)
            model_disc.load_state_dict(model_dict)
            del model_dict

        # reset lr if not continuing training.
        for group in optimizer_gen.param_groups:
            group['lr'] = c.lr_gen

        for group in optimizer_disc.param_groups:
            group['lr'] = c.lr_disc

        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model_gen.cuda()
        criterion_gen.cuda()
        model_disc.cuda()
        criterion_disc.cuda()

    # DISTRIBUTED
    # if num_gpus > 1:
    #     model = apply_gradient_allreduce(model)

    if c.noam_schedule:
        scheduler_gen = NoamLR(optimizer_gen,
                               warmup_steps=c.warmup_steps_gen,
                               last_epoch=args.restore_step - 1)
        # NOTE(review): the discriminator schedule also uses
        # warmup_steps_gen -- confirm whether a separate value is intended.
        scheduler_disc = NoamLR(optimizer_disc,
                                warmup_steps=c.warmup_steps_gen,
                                last_epoch=args.restore_step - 1)
    else:
        scheduler_gen, scheduler_disc = None, None

    num_params = count_parameters(model_gen)
    print(" > Generator has {} parameters".format(num_params), flush=True)
    num_params = count_parameters(model_disc)
    print(" > Discriminator has {} parameters".format(num_params), flush=True)

    # no best loss is carried over from a checkpoint
    best_loss = float('inf')

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model_gen, criterion_gen, optimizer_gen,
                               model_disc, criterion_disc, optimizer_disc,
                               scheduler_gen, scheduler_disc, ap, global_step,
                               epoch)
        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, ap,
                                      global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(target_loss,
                                    best_loss,
                                    model_gen,
                                    optimizer_gen,
                                    model_disc,
                                    optimizer_disc,
                                    global_step,
                                    epoch,
                                    OUT_PATH,
                                    model_losses=eval_avg_loss_dict)
|
||||
|
||||
|
||||
if __name__ == '__main__':

    def _str2bool(value):
        """Parse a command-line boolean.

        argparse's `type=bool` turns every non-empty string (including
        "False") into True, so the text is interpreted explicitly.
        """
        return str(value).strip().lower() in ('1', 'true', 't', 'yes', 'y')

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--continue_path',
        type=str,
        help=
        'Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
        default='',
        required='--config_path' not in sys.argv)
    parser.add_argument(
        '--restore_path',
        type=str,
        help='Model file to be restored. Use to finetune a model.',
        default='')
    parser.add_argument('--config_path',
                        type=str,
                        help='Path to config file for training.',
                        required='--continue_path' not in sys.argv)
    parser.add_argument('--debug',
                        type=_str2bool,
                        default=False,
                        help='Do not verify commit integrity to run training.')

    # DISTRIBUTED
    parser.add_argument(
        '--rank',
        type=int,
        default=0,
        help='DISTRIBUTED: process rank for distributed training.')
    parser.add_argument('--group_id',
                        type=str,
                        default="",
                        help='DISTRIBUTED: process group id.')
    args = parser.parse_args()

    if args.continue_path != '':
        # continue from the most recently created checkpoint in the folder
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, 'config.json')
        list_of_files = glob.glob(args.continue_path + "/*.pth.tar")
        if not list_of_files:
            raise FileNotFoundError(
                f" [!] No '*.pth.tar' checkpoint found under {args.continue_path}")
        latest_model_file = max(list_of_files, key=os.path.getctime)
        args.restore_path = latest_model_file
        print(f" > Training continues for {args.restore_path}")

    # setup output paths and read configs
    c = load_config(args.config_path)
    check_config(c)
    _ = os.path.dirname(os.path.realpath(__file__))

    OUT_PATH = args.continue_path
    if args.continue_path == '':
        OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
                                            args.debug)

    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')

    c_logger = ConsoleLogger()

    if args.rank == 0:
        os.makedirs(AUDIO_PATH, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        copy_config_file(args.config_path,
                         os.path.join(OUT_PATH, 'config.json'), new_fields)
        os.chmod(AUDIO_PATH, 0o775)
        os.chmod(OUT_PATH, 0o775)

        LOG_DIR = OUT_PATH
        tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')

        # write model desc to tensorboard
        tb_logger.tb_add_text('model-description', c['run_description'], 0)

    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)
|
|
@ -0,0 +1,97 @@
|
|||
import datetime
|
||||
from TTS.utils.io import AttrDict
|
||||
|
||||
|
||||
# ANSI escape sequences used to style console output.
_ANSI_CODES = {
    'OKBLUE': '\033[94m',
    'HEADER': '\033[95m',
    'OKGREEN': '\033[92m',
    'WARNING': '\033[93m',
    'FAIL': '\033[91m',
    'ENDC': '\033[0m',
    'BOLD': '\033[1m',
    'UNDERLINE': '\033[4m',
}
tcolors = AttrDict(_ANSI_CODES)
|
||||
|
||||
|
||||
class ConsoleLogger():
    """Pretty-printer for training and evaluation progress on the console."""

    def __init__(self):
        # TODO: color code for value changes
        # use these to compare values between iterations
        self.old_train_loss_dict = None
        self.old_epoch_loss_dict = None
        self.old_eval_loss_dict = None

    # pylint: disable=no-self-use
    def get_time(self):
        """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
        now = datetime.datetime.now()
        return now.strftime("%Y-%m-%d %H:%M:%S")

    def print_epoch_start(self, epoch, max_epoch):
        """Print the header line of an epoch."""
        print("\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD,
                                               epoch, max_epoch, tcolors.ENDC),
              flush=True)

    def print_train_start(self):
        """Print the training banner with the current time."""
        print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")

    def print_train_step(self, batch_steps, step, global_step,
                         step_time, loader_time, lr,
                         loss_dict, avg_loss_dict):
        """Print the losses, timings and learning rate of one training step.

        For every key in `loss_dict` the running average is appended in
        parentheses when `avg_loss_dict` holds an 'avg_<key>' entry.
        """
        indent = "     | > "
        print()
        log_text = "{}   --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(
            tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC)
        for key, value in loss_dict.items():
            # print the avg value if given
            if f'avg_{key}' in avg_loss_dict.keys():
                log_text += "{}{}: {:.5f}  ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
            else:
                log_text += "{}{}: {:.5f} \n".format(indent, key, value)
        log_text += f"{indent}step_time: {step_time:.2f}\n{indent}loader_time: {loader_time:.2f}\n{indent}lr: {lr:.5f}"
        print(log_text, flush=True)

    def print_train_epoch_end(self, global_step, epoch, epoch_time,
                              print_dict):
        """Print the summary of a finished training epoch.

        `print_dict` must provide `.items()` yielding (name, value) pairs.
        `epoch` is accepted for interface compatibility but not printed.
        """
        # pylint: disable=unused-argument
        indent = "     | > "
        # report the elapsed time, not the epoch index
        log_text = f"\n{tcolors.BOLD}   --> TRAIN PERFORMACE -- EPOCH TIME: {epoch_time:.2f} sec -- GLOBAL_STEP: {global_step}{tcolors.ENDC}\n"
        for key, value in print_dict.items():
            log_text += "{}{}: {:.5f}\n".format(indent, key, value)
        print(log_text, flush=True)

    def print_eval_start(self):
        """Print the evaluation banner."""
        print(f"{tcolors.BOLD} > EVALUATION {tcolors.ENDC}\n")

    def print_eval_step(self, step, loss_dict, avg_loss_dict):
        """Print the losses of one evaluation step, with averages if given."""
        indent = "     | > "
        print()
        log_text = f"{tcolors.BOLD}   --> STEP: {step}{tcolors.ENDC}\n"
        for key, value in loss_dict.items():
            # print the avg value if given
            if f'avg_{key}' in avg_loss_dict.keys():
                log_text += "{}{}: {:.5f}  ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
            else:
                log_text += "{}{}: {:.5f} \n".format(indent, key, value)
        print(log_text, flush=True)

    def print_epoch_end(self, epoch, avg_loss_dict):
        """Print evaluation averages, colored by their change since the last call.

        A decrease is shown in green, an increase in red. `avg_loss_dict` is
        stored for the comparison on the next call.
        """
        # pylint: disable=unused-argument
        indent = "     | > "
        log_text = "  {}--> EVAL PERFORMANCE{}\n".format(
            tcolors.BOLD, tcolors.ENDC)
        previous = self.old_eval_loss_dict
        for key, value in avg_loss_dict.items():
            color = ''
            sign = '+'
            diff = 0
            # keys that were not reported last time have nothing to compare to
            if previous is not None and key in previous:
                diff = value - previous[key]
                if diff < 0:
                    color = tcolors.OKGREEN
                    sign = ''
                elif diff > 0:
                    color = tcolors.FAIL
                    sign = '+'
            log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff)
        self.old_eval_loss_dict = avg_loss_dict
        print(log_text, flush=True)
|
|
@ -0,0 +1,102 @@
|
|||
import re
|
||||
import importlib
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from TTS.utils.visual import plot_spectrogram
|
||||
|
||||
|
||||
def plot_results(y_hat, y, ap, global_step, name_prefix):
    """ Plot vocoder model results

    Builds a waveform comparison figure and three spectrogram figures
    (generated, ground truth, absolute difference) for the first item of
    the batch.

    Args:
        y_hat: batch of generated waveforms; a tensor supporting
            ``detach().cpu().numpy()``. Presumably shaped
            (batch, 1, samples) -- TODO confirm against the caller.
        y: batch of ground-truth waveforms, same layout as ``y_hat``.
        ap: object providing ``spectrogram(waveform)``; also forwarded to
            ``plot_spectrogram``.
        global_step: step number shown in the generated-speech title.
        name_prefix: string prepended to every key of the returned dict.

    Returns:
        dict mapping ``name_prefix + <figure name>`` to the figure objects.
    """

    # select an instance from batch
    y_hat = y_hat[0].squeeze(0).detach().cpu().numpy()
    y = y[0].squeeze(0).detach().cpu().numpy()

    spec_fake = ap.spectrogram(y_hat).T
    spec_real = ap.spectrogram(y).T
    spec_diff = np.abs(spec_fake - spec_real)

    # plot figure and save it
    fig_wave = plt.figure()
    plt.subplot(2, 1, 1)
    plt.plot(y)
    plt.title("groundtruth speech")
    plt.subplot(2, 1, 2)
    plt.plot(y_hat)
    plt.title(f"generated speech @ {global_step} steps")
    plt.tight_layout()
    plt.close()

    # All keys share the same "<prefix><name>" layout so the figures end
    # up under one tag group; the "fake" key used to carry an extra
    # leading slash that separated it from the other three.
    figures = {
        name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake, ap),
        name_prefix + "spectrogram/real": plot_spectrogram(spec_real, ap),
        name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff, ap),
        name_prefix + "speech_comparison": fig_wave,
    }
    return figures
|
||||
|
||||
|
||||
def to_camel(text):
    """Convert a snake_case name to CamelCase.

    The text is first passed through ``str.capitalize`` (first character
    upper-cased, the rest lower-cased). Every underscore that is not at
    the start and is followed by an ASCII letter is then dropped and that
    letter upper-cased.
    """
    def _upper_letter(match):
        return match.group(1).upper()

    capitalized = text.capitalize()
    return re.sub(r'(?!^)_([a-zA-Z])', _upper_letter, capitalized)
|
||||
|
||||
|
||||
def setup_generator(c):
    """Instantiate the generator model selected by the config.

    The module ``TTS.vocoder.models.<generator_model lower-cased>`` is
    imported and the class named ``to_camel(generator_model)`` is built
    with the settings for that model.

    Args:
        c: config object exposing ``generator_model``,
            ``audio['num_mels']`` and ``generator_model_params`` (with
            ``upsample_factors`` and ``num_res_blocks``).

    Returns:
        The constructed generator model.

    Raises:
        NotImplementedError: for ``melgan_fb_generator``, which has no
            construction code yet.
        ValueError: for any other unknown generator name.
    """
    print(" > Generator Model: {}".format(c.generator_model))
    MyModel = importlib.import_module('TTS.vocoder.models.' +
                                      c.generator_model.lower())
    MyModel = getattr(MyModel, to_camel(c.generator_model))
    # Compare by equality: a substring test (`name in '...'`) makes
    # 'melgan_generator' match 'multiband_melgan_generator' as well, which
    # rebuilt the plain MelGAN generator with the multiband settings.
    # Lower-casing mirrors how the module name is resolved above.
    model_name = c.generator_model.lower()
    if model_name == 'melgan_generator':
        model = MyModel(
            in_channels=c.audio['num_mels'],
            out_channels=1,
            proj_kernel=7,
            base_channels=512,
            upsample_factors=c.generator_model_params['upsample_factors'],
            res_kernel=3,
            num_res_blocks=c.generator_model_params['num_res_blocks'])
    elif model_name == 'multiband_melgan_generator':
        model = MyModel(
            in_channels=c.audio['num_mels'],
            out_channels=4,
            proj_kernel=7,
            base_channels=384,
            upsample_factors=c.generator_model_params['upsample_factors'],
            res_kernel=3,
            num_res_blocks=c.generator_model_params['num_res_blocks'])
    elif model_name == 'melgan_fb_generator':
        raise NotImplementedError(
            "Generator model '{}' is not implemented yet.".format(
                c.generator_model))
    else:
        raise ValueError("Unknown generator model: '{}'".format(
            c.generator_model))
    return model
|
||||
|
||||
|
||||
def setup_discriminator(c):
    """Instantiate the discriminator model selected by the config.

    The module ``TTS.vocoder.models.<discriminator_model lower-cased>`` is
    imported and the class named ``to_camel(discriminator_model)`` is
    built with the settings for that model.

    Args:
        c: config object exposing ``discriminator_model``, ``audio``
            (``num_mels``, ``hop_length``) and
            ``discriminator_model_params``.

    Returns:
        The constructed discriminator model.

    Raises:
        ValueError: if the discriminator name is not one of the supported
            models.
    """
    print(" > Discriminator Model: {}".format(c.discriminator_model))
    MyModel = importlib.import_module('TTS.vocoder.models.' +
                                      c.discriminator_model.lower())
    MyModel = getattr(MyModel, to_camel(c.discriminator_model))
    # Compare by equality: a substring test (`name in '...'`) also accepts
    # fragments of a model name (or an empty string). Lower-casing mirrors
    # how the module name is resolved above.
    model_name = c.discriminator_model.lower()
    if model_name == 'random_window_discriminator':
        params = c.discriminator_model_params
        # NOTE: 'donwsample' is the spelling used by the keyword argument
        # and the config key, so it must stay as is.
        model = MyModel(
            cond_channels=c.audio['num_mels'],
            hop_length=c.audio['hop_length'],
            uncond_disc_donwsample_factors=params[
                'uncond_disc_donwsample_factors'],
            cond_disc_downsample_factors=params[
                'cond_disc_downsample_factors'],
            cond_disc_out_channels=params['cond_disc_out_channels'],
            window_sizes=params['window_sizes'])
    elif model_name == 'melgan_multiscale_discriminator':
        params = c.discriminator_model_params
        model = MyModel(
            in_channels=1,
            out_channels=1,
            kernel_sizes=(5, 3),
            base_channels=params['base_channels'],
            max_channels=params['max_channels'],
            downsample_factors=params['downsample_factors'])
    else:
        raise ValueError("Unknown discriminator model: '{}'".format(
            c.discriminator_model))
    return model
|
||||
|
||||
|
||||
# def check_config(c):
|
||||
# pass
|
|
@ -0,0 +1,52 @@
|
|||
import os
|
||||
import torch
|
||||
import datetime
|
||||
|
||||
|
||||
def save_model(model, optimizer, model_disc, optimizer_disc, current_step,
               epoch, output_path, **kwargs):
    """Serialize generator/discriminator training state to ``output_path``.

    The saved dict holds the two model state dicts, the two optimizer
    state dicts (``None`` for an optimizer passed as ``None``), the step,
    the epoch and today's date as text. Extra keyword arguments are merged
    into the dict and override entries with the same key.
    """
    def _optional_state(opt):
        # an optimizer may be absent; store None in that case
        return None if opt is None else opt.state_dict()

    generator_weights = model.state_dict()
    discriminator_weights = model_disc.state_dict()
    generator_optim = _optional_state(optimizer)
    discriminator_optim = _optional_state(optimizer_disc)

    checkpoint = dict(
        model=generator_weights,
        optimizer=generator_optim,
        model_disc=discriminator_weights,
        optimizer_disc=discriminator_optim,
        step=current_step,
        epoch=epoch,
        date=datetime.date.today().strftime("%B %d, %Y"),
    )
    checkpoint.update(kwargs)
    torch.save(checkpoint, output_path)
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, model_disc, optimizer_disc, current_step,
                    epoch, output_folder, **kwargs):
    """Save the training state as ``checkpoint_<current_step>.pth.tar``.

    The file is written inside ``output_folder`` via ``save_model``; extra
    keyword arguments are forwarded unchanged.
    """
    checkpoint_path = os.path.join(output_folder,
                                   f'checkpoint_{current_step}.pth.tar')
    print(" > CHECKPOINT : {}".format(checkpoint_path))
    save_model(model, optimizer, model_disc, optimizer_disc, current_step,
               epoch, checkpoint_path, **kwargs)
|
||||
|
||||
|
||||
def save_best_model(target_loss, best_loss, model, optimizer, model_disc,
                    optimizer_disc, current_step, epoch, output_folder,
                    **kwargs):
    """Save the training state as the best model when the loss improved.

    When ``target_loss`` is strictly lower than ``best_loss`` the state is
    written to ``best_model.pth.tar`` inside ``output_folder`` (with
    ``model_loss`` set to ``target_loss``) and ``target_loss`` is returned.
    Otherwise nothing is saved and ``best_loss`` is returned.
    """
    # guard clause: no improvement, keep the current best
    if not target_loss < best_loss:
        return best_loss

    checkpoint_path = os.path.join(output_folder, 'best_model.pth.tar')
    print(" > BEST MODEL : {}".format(checkpoint_path))
    save_model(model, optimizer, model_disc, optimizer_disc, current_step,
               epoch, checkpoint_path, model_loss=target_loss, **kwargs)
    return target_loss
|
Loading…
Reference in New Issue