mirror of https://github.com/coqui-ai/TTS.git
multiscale subband stft loss and missing tb_logger for eval stats

commit 3dff81e7e2
parent d425daad5b
@@ -49,11 +49,13 @@
     // LOSS PARAMETERS
     "use_stft_loss": true,
+    "use_subband_stft_loss": true,
     "use_mse_gan_loss": true,
     "use_hinge_gan_loss": false,
     "use_feat_match_loss": false,  // use only with melgan discriminators

     "stft_loss_weight": 1,
+    "subband_stft_loss_weight": 1,
     "mse_gan_loss_weight": 2.5,
     "hinge_gan_loss_weight": 2.5,
     "feat_match_loss_weight": 10.0,

@@ -63,6 +65,12 @@
         "hop_lengths": [120, 240, 50],
         "win_lengths": [600, 1200, 240]
     },
+    "subband_stft_loss_params":{
+        "n_ffts": [384, 683, 171],
+        "hop_lengths": [30, 60, 10],
+        "win_lengths": [150, 300, 60]
+    },
+
     "target_loss": "avg_G_loss",  // loss value to pick the best model

     // DISCRIMINATOR
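For orientation, the subband hops and windows are roughly the full-band values scaled down by the PQMF band count (120→30, 600→150, assuming the usual 4-band setup), so each band is analyzed at a comparable time resolution. A minimal sketch of how these new keys are consumed; the import path is an assumption, and the class itself is defined in the losses hunk below:

    # Sketch only: wiring the new config block into the loss class, the same
    # way GeneratorLoss.__init__ does further down. Import path is assumed.
    from TTS.vocoder.layers.losses import MultiScaleSubbandSTFTLoss

    subband_stft_loss_params = {
        "n_ffts": [384, 683, 171],
        "hop_lengths": [30, 60, 10],
        "win_lengths": [150, 300, 60],
    }
    subband_loss = MultiScaleSubbandSTFTLoss(**subband_stft_loss_params)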
@@ -49,7 +49,6 @@ class STFTLoss(nn.Module):
         loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro")
         return loss_mag, loss_sc

-
 class MultiScaleSTFTLoss(torch.nn.Module):
     def __init__(self,
                  n_ffts=[1024, 2048, 512],
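The pair returned above is a magnitude loss plus spectral convergence. A standalone sketch of the spectral-convergence term, lifted from the line in the hunk; the tensor shapes are illustrative:

    # Spectral convergence in isolation: Frobenius norm of the
    # magnitude-spectrogram error, normalized by the target's energy,
    # which makes the term scale-free.
    import torch

    def spectral_convergence(y_M, y_hat_M):
        return torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro")

    mag_target = torch.rand(513, 100) + 1e-3  # |STFT| of the target (assumed shape)
    mag_pred = torch.rand(513, 100) + 1e-3    # |STFT| of the prediction
    print(spectral_convergence(mag_target, mag_pred))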
@@ -73,6 +72,13 @@ class MultiScaleSTFTLoss(torch.nn.Module):
         return loss_mag, loss_sc


+class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
+    def forward(self, y_hat, y):
+        y_hat = y_hat.view(-1, 1, y_hat.shape[2])
+        y = y.view(-1, 1, y.shape[2])
+        return super().forward(y_hat.squeeze(1), y.squeeze(1))
+
+
 class MSEGLoss(nn.Module):
     """ Mean Squared Generator Loss """
     def __init__(self,):
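The subclass scores PQMF subbands with the parent multiscale loss by folding the band axis into the batch axis. A shape walkthrough with assumed sizes (batch 2, 4 bands, 4000 samples per band):

    import torch

    y_hat = torch.randn(2, 4, 4000)           # [B, C, T'] subband predictions
    flat = y_hat.view(-1, 1, y_hat.shape[2])  # [8, 1, 4000] bands folded into batch
    flat = flat.squeeze(1)                    # [8, 4000] what MultiScaleSTFTLoss expects
    assert flat.shape == (8, 4000)

The view to [B*C, 1, T'] followed by squeeze(1) could equally be a single view(-1, y_hat.shape[2]); the class keeps the intermediate channel axis explicit.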
@@ -145,17 +151,21 @@ class GeneratorLoss(nn.Module):
             " [!] Cannot use HingeGANLoss and MSEGANLoss together."

         self.use_stft_loss = C.use_stft_loss
+        self.use_subband_stft_loss = C.use_subband_stft_loss
         self.use_mse_gan_loss = C.use_mse_gan_loss
         self.use_hinge_gan_loss = C.use_hinge_gan_loss
         self.use_feat_match_loss = C.use_feat_match_loss

         self.stft_loss_weight = C.stft_loss_weight
+        self.subband_stft_loss_weight = C.subband_stft_loss_weight
         self.mse_gan_loss_weight = C.mse_gan_loss_weight
         self.hinge_gan_loss_weight = C.hinge_gan_loss_weight
         self.feat_match_loss_weight = C.feat_match_loss_weight

         if C.use_stft_loss:
             self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
+        if C.use_subband_stft_loss:
+            self.subband_stft_loss = MultiScaleSubbandSTFTLoss(**C.subband_stft_loss_params)
         if C.use_mse_gan_loss:
             self.mse_loss = MSEGLoss()
         if C.use_hinge_gan_loss:
@@ -163,7 +173,7 @@ class GeneratorLoss(nn.Module):
         if C.use_feat_match_loss:
             self.feat_match_loss = MelganFeatureLoss()

-    def forward(self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None):
+    def forward(self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None):
         loss = 0
         return_dict = {}

@@ -174,6 +184,13 @@ class GeneratorLoss(nn.Module):
             return_dict['G_stft_loss_sc'] = stft_loss_sc
             loss += self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)

+        # subband STFT Loss
+        if self.use_subband_stft_loss:
+            subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
+            return_dict['G_subband_stft_loss_mg'] = subband_stft_loss_mg
+            return_dict['G_subband_stft_loss_sc'] = subband_stft_loss_sc
+            loss += self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)
+
         # Fake Losses
         if self.use_mse_gan_loss and scores_fake is not None:
             mse_fake_loss = 0
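With the config above (hinge GAN and feature matching disabled), the accumulated objective reduces to a weighted sum of the two STFT pairs and the MSE adversarial term. A schematic with plain floats standing in for the computed loss tensors:

    # Net effect on the generator objective; default weights mirror the
    # config hunk ("stft_loss_weight": 1, "subband_stft_loss_weight": 1,
    # "mse_gan_loss_weight": 2.5).
    def generator_objective(stft_mg, stft_sc, sub_mg, sub_sc, mse_fake,
                            stft_w=1.0, sub_w=1.0, mse_w=2.5):
        return (stft_w * (stft_mg + stft_sc)
                + sub_w * (sub_mg + sub_sc)
                + mse_w * mse_fake)

    print(generator_objective(0.8, 0.4, 0.9, 0.5, 0.3))  # example magnitudes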
@@ -122,30 +122,27 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         # generator pass
         optimizer_G.zero_grad()
         y_hat = model_G(c_G)
-
-        in_real_D = y_hat
-        in_fake_D = y_G
+        y_hat_sub = None
+        y_G_sub = None

         # PQMF formatting
         if y_hat.shape[1] > 1:
-            in_real_D = y_G
-            in_fake_D = model_G.pqmf_synthesis(y_hat)
-            y_G = model_G.pqmf_analysis(y_G)
-            y_hat = y_hat.view(-1, 1, y_hat.shape[2])
-            y_G = y_G.view(-1, 1, y_G.shape[2])
+            y_hat_sub = y_hat
+            y_hat = model_G.pqmf_synthesis(y_hat)
+            y_G_sub = model_G.pqmf_analysis(y_G)

         if global_step > c.steps_to_start_discriminator:

             # run D with or without cond. features
             if len(signature(model_D.forward).parameters) == 2:
-                D_out_fake = model_D(in_fake_D, c_G)
+                D_out_fake = model_D(y_hat, c_G)
             else:
-                D_out_fake = model_D(in_fake_D)
+                D_out_fake = model_D(y_hat)
             D_out_real = None

             if c.use_feat_match_loss:
                 with torch.no_grad():
-                    D_out_real = model_D(in_real_D)
+                    D_out_real = model_D(y_G)

             # format D outputs
             if isinstance(D_out_fake, tuple):
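The signature(...) test above dispatches between conditional and unconditional discriminators: a bound forward with two remaining parameters is taken to accept conditioning features. A standalone illustration with hypothetical callables:

    from inspect import signature

    def uncond_d(x):     # e.g. a MelGAN-style discriminator
        return x

    def cond_d(x, c):    # e.g. a conditional (ParallelWaveGAN-style) discriminator
        return x + c

    for d in (uncond_d, cond_d):
        n_params = len(signature(d).parameters)
        print(d.__name__, "takes conditioning" if n_params == 2 else "waveform only")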
@@ -162,7 +159,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,

         # compute losses
         loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
-                                  feats_real)
+                                  feats_real, y_hat_sub, y_G_sub)
         loss_G = loss_G_dict['G_loss']

         # optimizer generator
@@ -272,12 +269,12 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
                                        model_losses=loss_dict)

             # compute spectrograms
-            figures = plot_results(in_fake_D, in_real_D, ap, global_step,
+            figures = plot_results(y_hat, y_G, ap, global_step,
                                    'train')
             tb_logger.tb_train_figures(global_step, figures)

             # Sample audio
-            sample_voice = in_fake_D[0].squeeze(0).detach().cpu().numpy()
+            sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
             tb_logger.tb_train_audios(global_step,
                                       {'train/audio': sample_voice},
                                       c.audio["sample_rate"])
@@ -320,23 +317,20 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):

         # generator pass
         y_hat = model_G(c_G)
-
-        in_real_D = y_hat
-        in_fake_D = y_G
+        y_hat_sub = None
+        y_G_sub = None

         # PQMF formatting
         if y_hat.shape[1] > 1:
-            in_real_D = y_G
-            in_fake_D = model_G.pqmf_synthesis(y_hat)
-            y_G = model_G.pqmf_analysis(y_G)
-            y_hat = y_hat.view(-1, 1, y_hat.shape[2])
-            y_G = y_G.view(-1, 1, y_G.shape[2])
+            y_hat_sub = y_hat
+            y_hat = model_G.pqmf_synthesis(y_hat)
+            y_G_sub = model_G.pqmf_analysis(y_G)

-        D_out_fake = model_D(in_fake_D)
+        D_out_fake = model_D(y_hat)
         D_out_real = None
         if c.use_feat_match_loss:
             with torch.no_grad():
-                D_out_real = model_D(in_real_D)
+                D_out_real = model_D(y_G)

         # format D outputs
         if isinstance(D_out_fake, tuple):
@@ -350,7 +344,7 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):

         # compute losses
         loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
-                                  feats_real)
+                                  feats_real, y_hat_sub, y_G_sub)

         loss_dict = dict()
         for key, value in loss_G_dict.items():
@@ -364,7 +358,7 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
         for key, value in loss_G_dict.items():
             update_eval_values['avg_' + key] = value.item()
         update_eval_values['avg_loader_time'] = loader_time
-        update_eval_values['avg_step_time'] = step_time
+        update_eval_values['avgP_step_time'] = step_time
         keep_avg.update_values(update_eval_values)

         # print eval stats
@@ -372,18 +366,19 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
             c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

     # compute spectrograms
-    figures = plot_results(in_fake_D, in_real_D, ap, global_step, 'eval')
+    figures = plot_results(y_hat, y_G, ap, global_step, 'eval')
     tb_logger.tb_eval_figures(global_step, figures)

     # Sample audio
-    sample_voice = in_fake_D[0].squeeze(0).detach().cpu().numpy()
+    sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
     tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
                              c.audio["sample_rate"])

     # synthesize a full voice
     data_loader.return_segments = False

+    tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
     return keep_avg.avg_values
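The second half of the commit title refers to this last hunk: evaluate() aggregated its stats into keep_avg but never wrote them to tensorboard, and the added tb_eval_stats call closes that gap. What logging such a dict amounts to, as a generic torch SummaryWriter sketch rather than the repo's tb_logger implementation:

    # Generic sketch only; values and step are stand-in numbers.
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter("runs/vocoder_eval")
    avg_values = {"avg_G_loss": 1.23, "avg_loader_time": 0.05}
    for name, value in avg_values.items():
        writer.add_scalar(f"EvalStats/{name}", value, global_step=1000)
    writer.close()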