Fixes for the vits model

This commit is contained in:
Eren Gölge 2021-08-26 17:15:09 +00:00
parent 5911eec3b1
commit 49e1181ea4
9 changed files with 44 additions and 34 deletions

View File

@@ -96,7 +96,7 @@ class VitsConfig(BaseTTSConfig):
model_args: VitsArgs = field(default_factory=VitsArgs)
# optimizer
grad_clip: List[float] = field(default_factory=lambda: [5, 5])
grad_clip: List[float] = field(default_factory=lambda: [1000, 1000])
lr_gen: float = 0.0002
lr_disc: float = 0.0002
lr_scheduler_gen: str = "ExponentialLR"

View File

@@ -593,7 +593,7 @@ class VitsGeneratorLoss(nn.Module):
):
"""
Shapes:
- wavefrom: :math:`[B, 1, T]`
- waveform : :math:`[B, 1, T]`
- waveform_hat: :math:`[B, 1, T]`
- z_p: :math:`[B, C, T]`
- logs_q: :math:`[B, C, T]`
@@ -651,7 +651,6 @@ class VitsDiscriminatorLoss(nn.Module):
return_dict = {}
loss_disc, _, _ = self.discriminator_loss(scores_disc_real, scores_disc_fake)
return_dict["loss_disc"] = loss_disc * self.disc_loss_alpha
loss = loss + loss_disc
return_dict["loss_disc"] = loss_disc
loss = loss + return_dict["loss_disc"]
return_dict["loss"] = loss
return return_dict

View File

@@ -2,7 +2,7 @@ import torch
from torch import nn
from torch.nn.modules.conv import Conv1d
from TTS.vocoder.models.hifigan_discriminator import MultiPeriodDiscriminator
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP, MultiPeriodDiscriminator
class DiscriminatorS(torch.nn.Module):
@@ -60,18 +60,32 @@ class VitsDiscriminator(nn.Module):
def __init__(self, use_spectral_norm=False):
super().__init__()
self.sd = DiscriminatorS(use_spectral_norm=use_spectral_norm)
self.mpd = MultiPeriodDiscriminator(use_spectral_norm=use_spectral_norm)
periods = [2, 3, 5, 7, 11]
def forward(self, x):
self.nets = nn.ModuleList()
self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm))
self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods])
def forward(self, x, x_hat=None):
"""
Args:
x (Tensor): input waveform.
x (Tensor): ground truth waveform.
x_hat (Tensor): predicted waveform.
Returns:
List[Tensor]: discriminator scores.
List[List[Tensor]]: list of list of features from each layers of each discriminator.
"""
scores, feats = self.mpd(x)
score_sd, feats_sd = self.sd(x)
return scores + [score_sd], feats + [feats_sd]
x_scores = []
x_hat_scores = [] if x_hat is not None else None
x_feats = []
x_hat_feats = [] if x_hat is not None else None
for net in self.nets:
x_score, x_feat = net(x)
x_scores.append(x_score)
x_feats.append(x_feat)
if x_hat is not None:
x_hat_score, x_hat_feat = net(x_hat)
x_hat_scores.append(x_hat_score)
x_hat_feats.append(x_hat_feat)
return x_scores, x_feats, x_hat_scores, x_hat_feats

View File

@@ -195,7 +195,7 @@ class VitsArgs(Coqpit):
inference_noise_scale: float = 0.667
length_scale: int = 1
noise_scale_dp: float = 1.0
inference_noise_scale_dp: float = 0.8
inference_noise_scale_dp: float = 1.
max_inference_len: int = None
init_discriminator: bool = True
use_spectral_norm_disriminator: bool = False
@@ -419,11 +419,11 @@ class Vits(BaseTTS):
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
with torch.no_grad():
o_scale = torch.exp(-2 * logs_p)
# logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
# logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp2 + logp3
logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp2 + logp3 + logp1 + logp4
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
# duration predictor
@@ -563,8 +563,9 @@ class Vits(BaseTTS):
outputs["waveform_seg"] = wav_seg
# compute discriminator scores and features
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(outputs["model_outputs"])
_, outputs["feats_disc_real"] = self.disc(wav_seg)
outputs["scores_disc_fake"], outputs["feats_disc_fake"], _, outputs["feats_disc_real"] = self.disc(
outputs["model_outputs"], wav_seg
)
# compute losses
with autocast(enabled=False): # use float32 for the criterion
@@ -587,15 +588,14 @@ class Vits(BaseTTS):
loss_dict["loss"] += outputs["nll_duration"]
else:
loss_dict["loss_duration"] = outputs["loss_duration"]
loss_dict["loss"] += outputs["nll_duration"]
loss_dict["loss"] += outputs["loss_duration"]
elif optimizer_idx == 1:
# discriminator pass
outputs = {}
# compute scores and features
outputs["scores_disc_fake"], outputs["feats_disc_fake"] = self.disc(self.y_disc_cache.detach())
outputs["scores_disc_real"], outputs["feats_disc_real"] = self.disc(self.wav_seg_disc_cache)
outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc(self.y_disc_cache.detach(), self.wav_seg_disc_cache)
# compute loss
with autocast(enabled=False): # use float32 for the criterion
@@ -686,14 +686,12 @@ class Vits(BaseTTS):
Returns:
List: optimizers.
"""
self.disc.requires_grad_(False)
gen_parameters = filter(lambda p: p.requires_grad, self.parameters())
self.disc.requires_grad_(True)
optimizer1 = get_optimizer(
gen_parameters = [param for name, param in self.named_parameters() if not str.startswith(name, "disc.")]
optimizer0 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
)
optimizer2 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
return [optimizer1, optimizer2]
optimizer1 = get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr_disc, self.disc)
return [optimizer0, optimizer1]
def get_lr(self) -> List:
"""Set the initial learning rates for each optimizer.

View File

@@ -225,6 +225,7 @@ def sequence_to_text(sequence: List, tp: Dict = None, add_blank=False, custom_sy
if custom_symbols is not None:
_symbols = custom_symbols
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
elif tp:
_symbols, _ = make_symbols(**tp)
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}

View File

@@ -1,8 +1,6 @@
# edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.autograd import Variable
def reduce_tensor(tensor, num_gpus):

View File

@@ -53,7 +53,6 @@ def get_commit_hash():
# Not copying .git folder into docker container
except (subprocess.CalledProcessError, FileNotFoundError):
commit = "0000000"
print(" > Git Hash: {}".format(commit))
return commit
@@ -62,7 +61,6 @@ def get_experiment_folder_path(root_path, model_name):
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
commit_hash = get_commit_hash()
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
print(" > Experiment folder: {}".format(output_folder))
return output_folder

View File

@@ -3,7 +3,7 @@ from TTS.utils.logging.tensorboard_logger import TensorboardLogger
from TTS.utils.logging.wandb_logger import WandbLogger
def init_logger(config):
def init_dashboard_logger(config):
if config.dashboard_logger == "tensorboard":
dashboard_logger = TensorboardLogger(config.output_log_path, model_name=config.model)

View File

@@ -29,11 +29,13 @@ class ConsoleLogger:
now = datetime.datetime.now()
return now.strftime("%Y-%m-%d %H:%M:%S")
def print_epoch_start(self, epoch, max_epoch):
def print_epoch_start(self, epoch, max_epoch, output_path = None):
print(
"\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC),
flush=True,
)
if output_path is not None:
print(f" --> {output_path}")
def print_train_start(self):
print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")