From 2620f62ea8dafbedf16482d1203134ff2b8dd7d8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Fri, 27 Aug 2021 07:07:07 +0000
Subject: [PATCH] Move duration_loss inside VitsGeneratorLoss

---
 TTS/trainer.py                      |  6 ++--
 TTS/tts/configs/vits_config.py      |  1 +
 TTS/tts/layers/losses.py            |  6 +++-
 TTS/tts/models/vits.py              | 43 +++++++++++++++--------------
 TTS/utils/logging/console_logger.py |  2 +-
 5 files changed, 34 insertions(+), 24 deletions(-)

diff --git a/TTS/trainer.py b/TTS/trainer.py
index 89bc5047..6a5c925a 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -423,7 +423,7 @@ class Trainer:
             optimizer: Target optimizer.
         """
         for group in optimizer.param_groups:
-            for p in group['params']:
+            for p in group["params"]:
                 yield p
 
     @staticmethod
@@ -528,7 +528,9 @@ class Trainer:
             scaler.scale(loss_dict["loss"]).backward()
             if grad_clip > 0:
                 scaler.unscale_(optimizer)
-                grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params(optimizer), grad_clip, error_if_nonfinite=False)
+                grad_norm = torch.nn.utils.clip_grad_norm_(
+                    self.master_params(optimizer), grad_clip, error_if_nonfinite=False
+                )
             # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
             if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                 grad_norm = 0
diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py
index 7264ef05..3bf0b13a 100644
--- a/TTS/tts/configs/vits_config.py
+++ b/TTS/tts/configs/vits_config.py
@@ -113,6 +113,7 @@ class VitsConfig(BaseTTSConfig):
     gen_loss_alpha: float = 1.0
     feat_loss_alpha: float = 1.0
     mel_loss_alpha: float = 45.0
+    dur_loss_alpha: float = 1.0
 
     # data loader params
     return_wav: bool = True
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index cba18674..0ce4ada9 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -524,6 +524,7 @@ class VitsGeneratorLoss(nn.Module):
         self.kl_loss_alpha = c.kl_loss_alpha
         self.gen_loss_alpha = c.gen_loss_alpha
         self.feat_loss_alpha = c.feat_loss_alpha
+        self.dur_loss_alpha = c.dur_loss_alpha
         self.mel_loss_alpha = c.mel_loss_alpha
         self.stft = TorchSTFT(
             c.audio.fft_size,
@@ -590,6 +591,7 @@ class VitsGeneratorLoss(nn.Module):
         scores_disc_fake,
         feats_disc_fake,
         feats_disc_real,
+        loss_duration,
     ):
         """
         Shapes:
@@ -615,12 +617,14 @@ class VitsGeneratorLoss(nn.Module):
         loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
         loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
         loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
-        loss = loss_kl + loss_feat + loss_mel + loss_gen
+        loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
+        loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
         return_dict["loss_feat"] = loss_feat
         return_dict["loss_mel"] = loss_mel
+        return_dict["loss_duration"] = loss_duration
         return_dict["loss"] = loss
         return return_dict
 
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index d2ad113d..3b9d82f9 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -1,4 +1,6 @@
+import math
 from dataclasses import dataclass, field
+from itertools import chain
 from typing import Dict, List, Tuple
 
 import torch
@@ -11,8 +13,6 @@ from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
-
-# from TTS.tts.layers.vits.sdp import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.speakers import get_speaker_manager
@@ -195,7 +195,7 @@
     inference_noise_scale: float = 0.667
     length_scale: int = 1
     noise_scale_dp: float = 1.0
-    inference_noise_scale_dp: float = 1.
+    inference_noise_scale_dp: float = 1.0
     max_inference_len: int = None
     init_discriminator: bool = True
     use_spectral_norm_disriminator: bool = False
@@ -429,14 +429,13 @@ class Vits(BaseTTS):
         # duration predictor
         attn_durations = attn.sum(3)
         if self.args.use_sdp:
-            nll_duration = self.duration_predictor(
+            loss_duration = self.duration_predictor(
                 x.detach() if self.args.detach_dp_input else x,
                 x_mask,
                 attn_durations,
                 g=g.detach() if self.args.detach_dp_input and g is not None else g,
             )
-            nll_duration = torch.sum(nll_duration.float() / torch.sum(x_mask))
-            outputs["nll_duration"] = nll_duration
+            loss_duration = loss_duration / torch.sum(x_mask)
         else:
             attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
             log_durations = self.duration_predictor(
@@ -445,7 +444,7 @@ class Vits(BaseTTS):
                 g=g.detach() if self.args.detach_dp_input and g is not None else g,
             )
             loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
-            outputs["loss_duration"] = loss_duration
+        outputs["loss_duration"] = loss_duration
 
         # expand prior
         m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
@@ -580,22 +579,17 @@ class Vits(BaseTTS):
                 scores_disc_fake=outputs["scores_disc_fake"],
                 feats_disc_fake=outputs["feats_disc_fake"],
                 feats_disc_real=outputs["feats_disc_real"],
+                loss_duration=outputs["loss_duration"],
             )
 
-            # handle the duration loss
-            if self.args.use_sdp:
-                loss_dict["nll_duration"] = outputs["nll_duration"]
-                loss_dict["loss"] += outputs["nll_duration"]
-            else:
-                loss_dict["loss_duration"] = outputs["loss_duration"]
-                loss_dict["loss"] += outputs["loss_duration"]
-
         elif optimizer_idx == 1:
             # discriminator pass
             outputs = {}
 
             # compute scores and features
-            outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc(self.y_disc_cache.detach(), self.wav_seg_disc_cache)
+            outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc(
+                self.y_disc_cache.detach(), self.wav_seg_disc_cache
+            )
 
             # compute loss
             with autocast(enabled=False):  # use float32 for the criterion
@@ -686,7 +680,16 @@ class Vits(BaseTTS):
         Returns:
             List: optimizers.
         """
-        gen_parameters = [param for name, param in self.named_parameters() if not str.startswith(name, "disc.")]
+        gen_parameters = chain(
+            self.text_encoder.parameters(),
+            self.posterior_encoder.parameters(),
+            self.flow.parameters(),
+            self.duration_predictor.parameters(),
+            self.waveform_decoder.parameters(),
+        )
+        # add the speaker embedding layer
+        if hasattr(self, "emb_g"):
+            gen_parameters = chain(gen_parameters, self.emb_g.parameters())
         optimizer0 = get_optimizer(
             self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
         )
@@ -710,9 +713,9 @@
         Returns:
             List: Schedulers, one for each optimizer.
""" - scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) - scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) - return [scheduler1, scheduler2] + scheduler0 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) + scheduler1 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) + return [scheduler0, scheduler1] def get_criterion(self): """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in diff --git a/TTS/utils/logging/console_logger.py b/TTS/utils/logging/console_logger.py index 0103d8b3..0c1aa862 100644 --- a/TTS/utils/logging/console_logger.py +++ b/TTS/utils/logging/console_logger.py @@ -29,7 +29,7 @@ class ConsoleLogger: now = datetime.datetime.now() return now.strftime("%Y-%m-%d %H:%M:%S") - def print_epoch_start(self, epoch, max_epoch, output_path = None): + def print_epoch_start(self, epoch, max_epoch, output_path=None): print( "\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC), flush=True,