Move duration_loss inside VitsGeneratorLoss

This commit is contained in:
Eren Gölge 2021-08-27 07:07:07 +00:00
parent 49e1181ea4
commit 2620f62ea8
5 changed files with 34 additions and 24 deletions

View File

@@ -423,7 +423,7 @@ class Trainer:
optimizer: Target optimizer. optimizer: Target optimizer.
""" """
for group in optimizer.param_groups: for group in optimizer.param_groups:
for p in group['params']: for p in group["params"]:
yield p yield p
@staticmethod @staticmethod
@@ -528,7 +528,9 @@ class Trainer:
scaler.scale(loss_dict["loss"]).backward() scaler.scale(loss_dict["loss"]).backward()
if grad_clip > 0: if grad_clip > 0:
scaler.unscale_(optimizer) scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params(optimizer), grad_clip, error_if_nonfinite=False) grad_norm = torch.nn.utils.clip_grad_norm_(
self.master_params(optimizer), grad_clip, error_if_nonfinite=False
)
# pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
if torch.isnan(grad_norm) or torch.isinf(grad_norm): if torch.isnan(grad_norm) or torch.isinf(grad_norm):
grad_norm = 0 grad_norm = 0

View File

@@ -113,6 +113,7 @@ class VitsConfig(BaseTTSConfig):
gen_loss_alpha: float = 1.0 gen_loss_alpha: float = 1.0
feat_loss_alpha: float = 1.0 feat_loss_alpha: float = 1.0
mel_loss_alpha: float = 45.0 mel_loss_alpha: float = 45.0
dur_loss_alpha: float = 1.0
# data loader params # data loader params
return_wav: bool = True return_wav: bool = True

View File

@@ -524,6 +524,7 @@ class VitsGeneratorLoss(nn.Module):
self.kl_loss_alpha = c.kl_loss_alpha self.kl_loss_alpha = c.kl_loss_alpha
self.gen_loss_alpha = c.gen_loss_alpha self.gen_loss_alpha = c.gen_loss_alpha
self.feat_loss_alpha = c.feat_loss_alpha self.feat_loss_alpha = c.feat_loss_alpha
self.dur_loss_alpha = c.dur_loss_alpha
self.mel_loss_alpha = c.mel_loss_alpha self.mel_loss_alpha = c.mel_loss_alpha
self.stft = TorchSTFT( self.stft = TorchSTFT(
c.audio.fft_size, c.audio.fft_size,
@@ -590,6 +591,7 @@ class VitsGeneratorLoss(nn.Module):
scores_disc_fake, scores_disc_fake,
feats_disc_fake, feats_disc_fake,
feats_disc_real, feats_disc_real,
loss_duration,
): ):
""" """
Shapes: Shapes:
@@ -615,12 +617,14 @@ class VitsGeneratorLoss(nn.Module):
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
loss = loss_kl + loss_feat + loss_mel + loss_gen loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
# pass losses to the dict # pass losses to the dict
return_dict["loss_gen"] = loss_gen return_dict["loss_gen"] = loss_gen
return_dict["loss_kl"] = loss_kl return_dict["loss_kl"] = loss_kl
return_dict["loss_feat"] = loss_feat return_dict["loss_feat"] = loss_feat
return_dict["loss_mel"] = loss_mel return_dict["loss_mel"] = loss_mel
return_dict["loss_duration"] = loss_duration
return_dict["loss"] = loss return_dict["loss"] = loss
return return_dict return return_dict

View File

@@ -1,4 +1,6 @@
import math
from dataclasses import dataclass, field from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import torch import torch
@@ -11,8 +13,6 @@ from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.layers.vits.discriminator import VitsDiscriminator from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
# from TTS.tts.layers.vits.sdp import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.speakers import get_speaker_manager from TTS.tts.utils.speakers import get_speaker_manager
@@ -195,7 +195,7 @@ class VitsArgs(Coqpit):
inference_noise_scale: float = 0.667 inference_noise_scale: float = 0.667
length_scale: int = 1 length_scale: int = 1
noise_scale_dp: float = 1.0 noise_scale_dp: float = 1.0
inference_noise_scale_dp: float = 1. inference_noise_scale_dp: float = 1.0
max_inference_len: int = None max_inference_len: int = None
init_discriminator: bool = True init_discriminator: bool = True
use_spectral_norm_disriminator: bool = False use_spectral_norm_disriminator: bool = False
@@ -429,14 +429,13 @@ class Vits(BaseTTS):
# duration predictor # duration predictor
attn_durations = attn.sum(3) attn_durations = attn.sum(3)
if self.args.use_sdp: if self.args.use_sdp:
nll_duration = self.duration_predictor( loss_duration = self.duration_predictor(
x.detach() if self.args.detach_dp_input else x, x.detach() if self.args.detach_dp_input else x,
x_mask, x_mask,
attn_durations, attn_durations,
g=g.detach() if self.args.detach_dp_input and g is not None else g, g=g.detach() if self.args.detach_dp_input and g is not None else g,
) )
nll_duration = torch.sum(nll_duration.float() / torch.sum(x_mask)) loss_duration = loss_duration / torch.sum(x_mask)
outputs["nll_duration"] = nll_duration
else: else:
attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
log_durations = self.duration_predictor( log_durations = self.duration_predictor(
@@ -445,7 +444,7 @@ class Vits(BaseTTS):
g=g.detach() if self.args.detach_dp_input and g is not None else g, g=g.detach() if self.args.detach_dp_input and g is not None else g,
) )
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask) loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
outputs["loss_duration"] = loss_duration outputs["loss_duration"] = loss_duration
# expand prior # expand prior
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p]) m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
@@ -580,22 +579,17 @@ class Vits(BaseTTS):
scores_disc_fake=outputs["scores_disc_fake"], scores_disc_fake=outputs["scores_disc_fake"],
feats_disc_fake=outputs["feats_disc_fake"], feats_disc_fake=outputs["feats_disc_fake"],
feats_disc_real=outputs["feats_disc_real"], feats_disc_real=outputs["feats_disc_real"],
loss_duration=outputs["loss_duration"]
) )
# handle the duration loss
if self.args.use_sdp:
loss_dict["nll_duration"] = outputs["nll_duration"]
loss_dict["loss"] += outputs["nll_duration"]
else:
loss_dict["loss_duration"] = outputs["loss_duration"]
loss_dict["loss"] += outputs["loss_duration"]
elif optimizer_idx == 1: elif optimizer_idx == 1:
# discriminator pass # discriminator pass
outputs = {} outputs = {}
# compute scores and features # compute scores and features
outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc(self.y_disc_cache.detach(), self.wav_seg_disc_cache) outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc(
self.y_disc_cache.detach(), self.wav_seg_disc_cache
)
# compute loss # compute loss
with autocast(enabled=False): # use float32 for the criterion with autocast(enabled=False): # use float32 for the criterion
@@ -686,7 +680,16 @@ class Vits(BaseTTS):
Returns: Returns:
List: optimizers. List: optimizers.
""" """
gen_parameters = [param for name, param in self.named_parameters() if not str.startswith(name, "disc.")] gen_parameters = chain(
self.text_encoder.parameters(),
self.posterior_encoder.parameters(),
self.flow.parameters(),
self.duration_predictor.parameters(),
self.waveform_decoder.parameters(),
)
# add the speaker embedding layer
if hasattr(self, "emb_g"):
gen_parameters = chain(gen_parameters, self.emb_g.parameters())
optimizer0 = get_optimizer( optimizer0 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
) )
@@ -710,9 +713,9 @@ class Vits(BaseTTS):
Returns: Returns:
List: Schedulers, one for each optimizer. List: Schedulers, one for each optimizer.
""" """
scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) scheduler0 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) scheduler1 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
return [scheduler1, scheduler2] return [scheduler0, scheduler1]
def get_criterion(self): def get_criterion(self):
"""Get criterions for each optimizer. The index in the output list matches the optimizer idx used in """Get criterions for each optimizer. The index in the output list matches the optimizer idx used in

View File

@@ -29,7 +29,7 @@ class ConsoleLogger:
now = datetime.datetime.now() now = datetime.datetime.now()
return now.strftime("%Y-%m-%d %H:%M:%S") return now.strftime("%Y-%m-%d %H:%M:%S")
def print_epoch_start(self, epoch, max_epoch, output_path = None): def print_epoch_start(self, epoch, max_epoch, output_path=None):
print( print(
"\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC), "\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC),
flush=True, flush=True,