mirror of https://github.com/coqui-ai/TTS.git
Move duration_loss inside VitsGeneratorLoss
This commit is contained in:
parent
49e1181ea4
commit
2620f62ea8
|
@ -423,7 +423,7 @@ class Trainer:
|
|||
optimizer: Target optimizer.
|
||||
"""
|
||||
for group in optimizer.param_groups:
|
||||
for p in group['params']:
|
||||
for p in group["params"]:
|
||||
yield p
|
||||
|
||||
@staticmethod
|
||||
|
@ -528,7 +528,9 @@ class Trainer:
|
|||
scaler.scale(loss_dict["loss"]).backward()
|
||||
if grad_clip > 0:
|
||||
scaler.unscale_(optimizer)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(self.master_params(optimizer), grad_clip, error_if_nonfinite=False)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.master_params(optimizer), grad_clip, error_if_nonfinite=False
|
||||
)
|
||||
# pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
|
||||
if torch.isnan(grad_norm) or torch.isinf(grad_norm):
|
||||
grad_norm = 0
|
||||
|
|
|
@ -113,6 +113,7 @@ class VitsConfig(BaseTTSConfig):
|
|||
gen_loss_alpha: float = 1.0
|
||||
feat_loss_alpha: float = 1.0
|
||||
mel_loss_alpha: float = 45.0
|
||||
dur_loss_alpha: float = 1.0
|
||||
|
||||
# data loader params
|
||||
return_wav: bool = True
|
||||
|
|
|
@ -524,6 +524,7 @@ class VitsGeneratorLoss(nn.Module):
|
|||
self.kl_loss_alpha = c.kl_loss_alpha
|
||||
self.gen_loss_alpha = c.gen_loss_alpha
|
||||
self.feat_loss_alpha = c.feat_loss_alpha
|
||||
self.dur_loss_alpha = c.dur_loss_alpha
|
||||
self.mel_loss_alpha = c.mel_loss_alpha
|
||||
self.stft = TorchSTFT(
|
||||
c.audio.fft_size,
|
||||
|
@ -590,6 +591,7 @@ class VitsGeneratorLoss(nn.Module):
|
|||
scores_disc_fake,
|
||||
feats_disc_fake,
|
||||
feats_disc_real,
|
||||
loss_duration,
|
||||
):
|
||||
"""
|
||||
Shapes:
|
||||
|
@ -615,12 +617,14 @@ class VitsGeneratorLoss(nn.Module):
|
|||
loss_gen = self.generator_loss(scores_disc_fake)[0] * self.gen_loss_alpha
|
||||
loss_kl = self.kl_loss(z_p, logs_q, m_p, logs_p, z_mask.unsqueeze(1)) * self.kl_loss_alpha
|
||||
loss_mel = torch.nn.functional.l1_loss(mel, mel_hat) * self.mel_loss_alpha
|
||||
loss = loss_kl + loss_feat + loss_mel + loss_gen
|
||||
loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
|
||||
loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
|
||||
# pass losses to the dict
|
||||
return_dict["loss_gen"] = loss_gen
|
||||
return_dict["loss_kl"] = loss_kl
|
||||
return_dict["loss_feat"] = loss_feat
|
||||
return_dict["loss_mel"] = loss_mel
|
||||
return_dict["loss_duration"] = loss_duration
|
||||
return_dict["loss"] = loss
|
||||
return return_dict
|
||||
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import chain
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
@ -11,8 +13,6 @@ from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
|
|||
from TTS.tts.layers.vits.discriminator import VitsDiscriminator
|
||||
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
|
||||
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
|
||||
|
||||
# from TTS.tts.layers.vits.sdp import StochasticDurationPredictor
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
from TTS.tts.utils.speakers import get_speaker_manager
|
||||
|
@ -195,7 +195,7 @@ class VitsArgs(Coqpit):
|
|||
inference_noise_scale: float = 0.667
|
||||
length_scale: int = 1
|
||||
noise_scale_dp: float = 1.0
|
||||
inference_noise_scale_dp: float = 1.
|
||||
inference_noise_scale_dp: float = 1.0
|
||||
max_inference_len: int = None
|
||||
init_discriminator: bool = True
|
||||
use_spectral_norm_disriminator: bool = False
|
||||
|
@ -429,14 +429,13 @@ class Vits(BaseTTS):
|
|||
# duration predictor
|
||||
attn_durations = attn.sum(3)
|
||||
if self.args.use_sdp:
|
||||
nll_duration = self.duration_predictor(
|
||||
loss_duration = self.duration_predictor(
|
||||
x.detach() if self.args.detach_dp_input else x,
|
||||
x_mask,
|
||||
attn_durations,
|
||||
g=g.detach() if self.args.detach_dp_input and g is not None else g,
|
||||
)
|
||||
nll_duration = torch.sum(nll_duration.float() / torch.sum(x_mask))
|
||||
outputs["nll_duration"] = nll_duration
|
||||
loss_duration = loss_duration/ torch.sum(x_mask)
|
||||
else:
|
||||
attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
|
||||
log_durations = self.duration_predictor(
|
||||
|
@ -445,7 +444,7 @@ class Vits(BaseTTS):
|
|||
g=g.detach() if self.args.detach_dp_input and g is not None else g,
|
||||
)
|
||||
loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
|
||||
outputs["loss_duration"] = loss_duration
|
||||
outputs["loss_duration"] = loss_duration
|
||||
|
||||
# expand prior
|
||||
m_p = torch.einsum("klmn, kjm -> kjn", [attn, m_p])
|
||||
|
@ -580,22 +579,17 @@ class Vits(BaseTTS):
|
|||
scores_disc_fake=outputs["scores_disc_fake"],
|
||||
feats_disc_fake=outputs["feats_disc_fake"],
|
||||
feats_disc_real=outputs["feats_disc_real"],
|
||||
loss_duration=outputs["loss_duration"]
|
||||
)
|
||||
|
||||
# handle the duration loss
|
||||
if self.args.use_sdp:
|
||||
loss_dict["nll_duration"] = outputs["nll_duration"]
|
||||
loss_dict["loss"] += outputs["nll_duration"]
|
||||
else:
|
||||
loss_dict["loss_duration"] = outputs["loss_duration"]
|
||||
loss_dict["loss"] += outputs["loss_duration"]
|
||||
|
||||
elif optimizer_idx == 1:
|
||||
# discriminator pass
|
||||
outputs = {}
|
||||
|
||||
# compute scores and features
|
||||
outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc(self.y_disc_cache.detach(), self.wav_seg_disc_cache)
|
||||
outputs["scores_disc_fake"], _, outputs["scores_disc_real"], _ = self.disc(
|
||||
self.y_disc_cache.detach(), self.wav_seg_disc_cache
|
||||
)
|
||||
|
||||
# compute loss
|
||||
with autocast(enabled=False): # use float32 for the criterion
|
||||
|
@ -686,7 +680,16 @@ class Vits(BaseTTS):
|
|||
Returns:
|
||||
List: optimizers.
|
||||
"""
|
||||
gen_parameters = [param for name, param in self.named_parameters() if not str.startswith(name, "disc.")]
|
||||
gen_parameters = chain(
|
||||
self.text_encoder.parameters(),
|
||||
self.posterior_encoder.parameters(),
|
||||
self.flow.parameters(),
|
||||
self.duration_predictor.parameters(),
|
||||
self.waveform_decoder.parameters(),
|
||||
)
|
||||
# add the speaker embedding layer
|
||||
if hasattr(self, "emb_g"):
|
||||
gen_parameters = chain(gen_parameters, self.emb_g)
|
||||
optimizer0 = get_optimizer(
|
||||
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, parameters=gen_parameters
|
||||
)
|
||||
|
@ -710,9 +713,9 @@ class Vits(BaseTTS):
|
|||
Returns:
|
||||
List: Schedulers, one for each optimizer.
|
||||
"""
|
||||
scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
|
||||
scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
|
||||
return [scheduler1, scheduler2]
|
||||
scheduler0 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
|
||||
scheduler1 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
|
||||
return [scheduler0, scheduler1]
|
||||
|
||||
def get_criterion(self):
|
||||
"""Get criterions for each optimizer. The index in the output list matches the optimizer idx used in
|
||||
|
|
|
@ -29,7 +29,7 @@ class ConsoleLogger:
|
|||
now = datetime.datetime.now()
|
||||
return now.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
def print_epoch_start(self, epoch, max_epoch, output_path = None):
|
||||
def print_epoch_start(self, epoch, max_epoch, output_path=None):
|
||||
print(
|
||||
"\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC),
|
||||
flush=True,
|
||||
|
|
Loading…
Reference in New Issue