Move duration_loss inside VitsGeneratorLoss

Eren Gölge 2021-08-27 07:07:07 +00:00
parent 49e1181ea4
commit 2620f62ea8
5 changed files with 34 additions and 24 deletions

View File

@@ -423,7 +423,7 @@ class Trainer:
             optimizer: Target optimizer.
         """
         for group in optimizer.param_groups:
-            for p in group['params']:
+            for p in group["params"]:
                 yield p

     @staticmethod
@@ -528,7 +528,9 @@ class Trainer:
             scaler.scale(loss_dict["loss"]).backward()
             if grad_clip > 0:
                 scaler.unscale_(optimizer)
-                grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params(optimizer), grad_clip, error_if_nonfinite=False)
+                grad_norm = torch.nn.utils.clip_grad_norm_(
+                    self.master_params(optimizer), grad_clip, error_if_nonfinite=False
+                )
             # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
             if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                 grad_norm = 0
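
A note on the guard above: clip_grad_norm_ returns the total gradient norm, and with error_if_nonfinite=False (available in recent PyTorch releases) a non-finite norm is returned rather than raised. A minimal standalone sketch of the same pattern, using a hypothetical model and optimizer that are not part of this repo:

    import torch

    model = torch.nn.Linear(4, 2)  # placeholder module
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, error_if_nonfinite=False)
    # Zero the reported norm when it is non-finite, mirroring the trainer guard above.
    if torch.isnan(grad_norm) or torch.isinf(grad_norm):
        grad_norm = 0
    optimizer.step()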

View File

@@ -113,6 +113,7 @@ class VitsConfig(BaseTTSConfig):
     gen_loss_alpha: float = 1.0
     feat_loss_alpha: float = 1.0
     mel_loss_alpha: float = 45.0
+    dur_loss_alpha: float = 1.0

     # data loader params
     return_wav: bool = True
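
For context, the new field is just one more scalar weight read by the loss criterion; a hedged usage sketch (the import path is assumed for this revision of the package):

    from TTS.tts.configs.vits_config import VitsConfig  # path assumed

    config = VitsConfig()
    config.dur_loss_alpha = 2.0  # e.g. upweight the duration term over the 1.0 default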

View File

@@ -524,6 +524,7 @@ class VitsGeneratorLoss(nn.Module):
         self.kl_loss_alpha = c.kl_loss_alpha
         self.gen_loss_alpha = c.gen_loss_alpha
         self.feat_loss_alpha = c.feat_loss_alpha
+        self.dur_loss_alpha = c.dur_loss_alpha
         self.mel_loss_alpha = c.mel_loss_alpha
         self.stft = TorchSTFT(
             c.audio.fft_size,
@@ -590,6 +591,7 @@ class VitsGeneratorLoss(nn.Module):
         scores_disc_fake,
         feats_disc_fake,
         feats_disc_real,
+        loss_duration,
     ):
         """
         Shapes:
@@ -615,12 +617,14 @@ class VitsGeneratorLoss(nn.Module):
         loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
         loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
         loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
-        loss = loss_kl + loss_feat + loss_mel + loss_gen
+        loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
+        loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration

         # pass losses to the dict
         return_dict["loss_gen"] = loss_gen
         return_dict["loss_kl"] = loss_kl
         return_dict["loss_feat"] = loss_feat
         return_dict["loss_mel"] = loss_mel
+        return_dict["loss_duration"] = loss_duration
         return_dict["loss"] = loss
         return return_dict
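
Since the total is now a plain weighted sum, it can be sanity-checked with scalar stand-ins; a minimal sketch (all values arbitrary, not real training numbers):

    import torch

    dur_loss_alpha = 1.0
    loss_kl, loss_feat, loss_mel, loss_gen = (
        torch.tensor(0.5), torch.tensor(1.2), torch.tensor(30.0), torch.tensor(2.0)
    )
    # Per-element duration losses are summed and weighted, as in forward() above.
    loss_duration = torch.sum(torch.tensor([0.1, 0.2, 0.3])) * dur_loss_alpha
    loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
    print(float(loss))  # ~34.3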

View File

@@ -1,4 +1,6 @@
 import math
+from dataclasses import dataclass, field
+from itertools import chain
 from typing import Dict, List, Tuple

 import torch
@@ -11,8 +13,6 @@ from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
 from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
-
-# from TTS.tts.layers.vits.sdp import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.speakers import get_speaker_manager
@@ -195,7 +195,7 @@ class VitsArgs(Coqpit):
     inference_noise_scale: float = 0.667
     length_scale: int = 1
     noise_scale_dp: float = 1.0
-    inference_noise_scale_dp: float = 1.
+    inference_noise_scale_dp: float = 1.0
     max_inference_len: int = None
     init_discriminator: bool = True
     use_spectral_norm_disriminator: bool = False
@@ -429,14 +429,13 @@ class Vits(BaseTTS):
         # duration predictor
         attn_durations = attn.sum(3)
         if self.args.use_sdp:
-            nll_duration = self.duration_predictor(
+            loss_duration = self.duration_predictor(
                 x.detach() if self.args.detach_dp_input else x,
                 x_mask,
                 attn_durations,
                 g=g.detach() if self.args.detach_dp_input and g is not None else g,
             )
-            nll_duration = torch.sum(nll_duration.float() / torch.sum(x_mask))
-            outputs["nll_duration"] = nll_duration
+            loss_duration = loss_duration / torch.sum(x_mask)
         else:
             attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
             log_durations = self.duration_predictor(
@@ -445,7 +444,7 @@ class Vits(BaseTTS):
                 g=g.detach() if self.args.detach_dp_input and g is not None else g,
             )
             loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
-            outputs["loss_duration"] = loss_duration
+        outputs["loss_duration"] = loss_duration

         # expand prior
         m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
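
Both branches now leave a mask-normalized loss_duration in outputs, so the criterion can consume it uniformly. A hedged sketch of the deterministic (non-SDP) branch with dummy tensors (shapes illustrative only):

    import torch

    attn_durations = torch.randint(0, 3, (2, 1, 4)).float()  # dummy alignment durations
    x_mask = torch.ones(2, 1, 4)                             # all tokens valid
    log_durations = torch.randn(2, 1, 4)                     # stand-in predictor output

    attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
    # MSE in log-duration space, normalized by the number of valid tokens.
    loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)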
@@ -580,22 +579,17 @@ class Vits(BaseTTS):
                     scores_disc_fake=outputs["scores_disc_fake"],
                     feats_disc_fake=outputs["feats_disc_fake"],
                     feats_disc_real=outputs["feats_disc_real"],
+                    loss_duration=outputs["loss_duration"],
                 )
-
-                # handle the duration loss
-                if self.args.use_sdp:
-                    loss_dict["nll_duration"] = outputs["nll_duration"]
-                    loss_dict["loss"] += outputs["nll_duration"]
-                else:
-                    loss_dict["loss_duration"] = outputs["loss_duration"]
-                    loss_dict["loss"] += outputs["loss_duration"]

         elif optimizer_idx == 1:
             # discriminator pass
             outputs = {}

             # compute scores and features
-            outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc(self.y_disc_cache.detach(), self.wav_seg_disc_cache)
+            outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc(
+                self.y_disc_cache.detach(), self.wav_seg_disc_cache
+            )

             # compute loss
             with autocast(enabled=False):  # use float32 for the criterion
@@ -686,7 +680,16 @@ class Vits(BaseTTS):
         Returns:
             List: optimizers.
         """
-        gen_parameters = [param for name, param in self.named_parameters() if not str.startswith(name, "disc.")]
+        gen_parameters = chain(
+            self.text_encoder.parameters(),
+            self.posterior_encoder.parameters(),
+            self.flow.parameters(),
+            self.duration_predictor.parameters(),
+            self.waveform_decoder.parameters(),
+        )
+        # add the speaker embedding layer
+        if hasattr(self, "emb_g"):
+            gen_parameters = chain(gen_parameters, self.emb_g)
         optimizer0 = get_optimizer(
             self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
         )
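
The chain-based grouping works because PyTorch optimizers accept any iterable of parameters; itertools.chain concatenates the submodules' parameter generators lazily. A generic sketch of the pattern with placeholder modules (not the model's real submodules):

    from itertools import chain

    import torch

    encoder = torch.nn.Linear(4, 4)  # placeholder submodules
    decoder = torch.nn.Linear(4, 4)

    # chain yields each parameter once, which is enough to build one optimizer.
    gen_parameters = chain(encoder.parameters(), decoder.parameters())
    optimizer = torch.optim.AdamW(gen_parameters, lr=2e-4)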
@@ -710,9 +713,9 @@ class Vits(BaseTTS):
         Returns:
             List: Schedulers, one for each optimizer.
         """
-        scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
-        scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
-        return [scheduler1, scheduler2]
+        scheduler0 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
+        scheduler1 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
+        return [scheduler0, scheduler1]

     def get_criterion(self):
         """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in

View File

@@ -29,7 +29,7 @@ class ConsoleLogger:
         now = datetime.datetime.now()
         return now.strftime("%Y-%m-%d %H:%M:%S")

-    def print_epoch_start(self, epoch, max_epoch, output_path = None):
+    def print_epoch_start(self, epoch, max_epoch, output_path=None):
         print(
             "\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC),
             flush=True,