add multi-scale subband STFT loss and the missing tb_logger call for eval stats

erogol 2020-06-02 17:07:15 +02:00
parent d425daad5b
commit 3dff81e7e2
3 changed files with 51 additions and 31 deletions


@@ -49,11 +49,13 @@
     // LOSS PARAMETERS
     "use_stft_loss": true,
+    "use_subband_stft_loss": true,
     "use_mse_gan_loss": true,
     "use_hinge_gan_loss": false,
     "use_feat_match_loss": false, // use only with melgan discriminators
     "stft_loss_weight": 1,
+    "subband_stft_loss_weight": 1,
     "mse_gan_loss_weight": 2.5,
     "hinge_gan_loss_weight": 2.5,
     "feat_match_loss_weight": 10.0,
@@ -63,6 +65,12 @@
         "hop_lengths": [120, 240, 50],
         "win_lengths": [600, 1200, 240]
     },
+    "subband_stft_loss_params":{
+        "n_ffts": [384, 683, 171],
+        "hop_lengths": [30, 60, 10],
+        "win_lengths": [150, 300, 60]
+    },
     "target_loss": "avg_G_loss", // loss value to pick the best model

     // DISCRIMINATOR
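
Note on the new values: each list entry in these blocks defines one STFT resolution, and the subband parameters are roughly the full-band ones divided by the number of PQMF bands (four for multi-band MelGAN), since each subband signal carries a quarter of the samples. A minimal sketch of how such a config feeds the two loss classes in this commit (the dict literals stand in for the parsed config object):

    # Sketch only; MultiScaleSTFTLoss and MultiScaleSubbandSTFTLoss are
    # the classes defined in the losses file below.
    stft_loss_params = {
        "n_ffts": [1024, 2048, 512],
        "hop_lengths": [120, 240, 50],
        "win_lengths": [600, 1200, 240],
    }
    subband_stft_loss_params = {
        "n_ffts": [384, 683, 171],
        "hop_lengths": [30, 60, 10],
        "win_lengths": [150, 300, 60],
    }
    stft_loss = MultiScaleSTFTLoss(**stft_loss_params)
    subband_stft_loss = MultiScaleSubbandSTFTLoss(**subband_stft_loss_params)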


@@ -49,7 +49,6 @@ class STFTLoss(nn.Module):
         loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro")
         return loss_mag, loss_sc
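
For context, this loss family (following the Parallel WaveGAN formulation) pairs the spectral convergence term above with a log STFT magnitude term. A hedged sketch of both terms, operating on magnitude spectrograms the method has already computed; this is not the repo's verbatim code:

    import torch
    import torch.nn.functional as F

    def stft_loss_terms(y_hat_M, y_M):
        # log STFT magnitude loss: L1 distance between log magnitudes
        loss_mag = F.l1_loss(torch.log(y_hat_M), torch.log(y_M))
        # spectral convergence: relative Frobenius error of the magnitudes
        loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro")
        return loss_mag, loss_sc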
 class MultiScaleSTFTLoss(torch.nn.Module):
     def __init__(self,
                  n_ffts=[1024, 2048, 512],
@@ -73,6 +72,13 @@ class MultiScaleSTFTLoss(torch.nn.Module):
         return loss_mag, loss_sc

+class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
+    def forward(self, y_hat, y):
+        y_hat = y_hat.view(-1, 1, y_hat.shape[2])
+        y = y.view(-1, 1, y.shape[2])
+        return super().forward(y_hat.squeeze(1), y.squeeze(1))
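
The reshape is the whole trick: a [B, C, T] batch of subband signals becomes B*C single-channel waveforms, so the parent multi-scale loss scores every subband independently at every resolution. A toy shape check, assuming four subbands (all values illustrative):

    import torch

    y_hat = torch.randn(8, 4, 4000)   # [batch, subbands, samples] from a multi-band generator
    y = torch.randn(8, 4, 4000)       # PQMF analysis of the target waveform
    y_hat_flat = y_hat.view(-1, 1, y_hat.shape[2]).squeeze(1)   # -> [32, 4000]
    y_flat = y.view(-1, 1, y.shape[2]).squeeze(1)               # -> [32, 4000]

The view to [B*C, 1, T] followed by squeeze(1) could be a single view to [B*C, T]; the two-step form mirrors the reshaping this commit removes from the training script.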

 class MSEGLoss(nn.Module):
     """ Mean Squared Generator Loss """
     def __init__(self,):
@@ -145,17 +151,21 @@ class GeneratorLoss(nn.Module):
             " [!] Cannot use HingeGANLoss and MSEGANLoss together."

         self.use_stft_loss = C.use_stft_loss
+        self.use_subband_stft_loss = C.use_subband_stft_loss
         self.use_mse_gan_loss = C.use_mse_gan_loss
         self.use_hinge_gan_loss = C.use_hinge_gan_loss
         self.use_feat_match_loss = C.use_feat_match_loss

         self.stft_loss_weight = C.stft_loss_weight
+        self.subband_stft_loss_weight = C.subband_stft_loss_weight
         self.mse_gan_loss_weight = C.mse_gan_loss_weight
         self.hinge_gan_loss_weight = C.hinge_gan_loss_weight
         self.feat_match_loss_weight = C.feat_match_loss_weight

         if C.use_stft_loss:
             self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
+        if C.use_subband_stft_loss:
+            self.subband_stft_loss = MultiScaleSubbandSTFTLoss(**C.subband_stft_loss_params)
         if C.use_mse_gan_loss:
             self.mse_loss = MSEGLoss()
         if C.use_hinge_gan_loss:
@@ -163,7 +173,7 @@ class GeneratorLoss(nn.Module):
         if C.use_feat_match_loss:
             self.feat_match_loss = MelganFeatureLoss()

-    def forward(self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None):
+    def forward(self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None):
         loss = 0
         return_dict = {}
@@ -174,6 +184,13 @@ class GeneratorLoss(nn.Module):
             return_dict['G_stft_loss_sc'] = stft_loss_sc
             loss += self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)

+        # subband STFT Loss
+        if self.use_subband_stft_loss:
+            subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
+            return_dict['G_subband_stft_loss_mg'] = subband_stft_loss_mg
+            return_dict['G_subband_stft_loss_sc'] = subband_stft_loss_sc
+            loss += self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)
+
         # Fake Losses
         if self.use_mse_gan_loss and scores_fake is not None:
             mse_fake_loss = 0
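
With both flags enabled, the criterion now expects the subband tensors as two extra arguments and returns a dict whose 'G_loss' entry drives the optimizer. A minimal call sketch using the names from the training script below (shapes assume four PQMF bands):

    # criterion_G is a GeneratorLoss built from the config above.
    loss_G_dict = criterion_G(y_hat=y_hat,            # full-band waveform [B, 1, T]
                              y=y_G,                  # target waveform    [B, 1, T]
                              scores_fake=scores_fake,
                              feats_fake=feats_fake,
                              feats_real=feats_real,
                              y_hat_sub=y_hat_sub,    # [B, 4, T/4] or None
                              y_sub=y_G_sub)          # [B, 4, T/4] or None
    loss_G = loss_G_dict['G_loss']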


@@ -122,30 +122,27 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
         # generator pass
         optimizer_G.zero_grad()
         y_hat = model_G(c_G)
-        in_real_D = y_hat
-        in_fake_D = y_G
+        y_hat_sub = None
+        y_G_sub = None

         # PQMF formatting
         if y_hat.shape[1] > 1:
-            in_real_D = y_G
-            in_fake_D = model_G.pqmf_synthesis(y_hat)
-            y_G = model_G.pqmf_analysis(y_G)
-            y_hat = y_hat.view(-1, 1, y_hat.shape[2])
-            y_G = y_G.view(-1, 1, y_G.shape[2])
+            y_hat_sub = y_hat
+            y_hat = model_G.pqmf_synthesis(y_hat)
+            y_G_sub = model_G.pqmf_analysis(y_G)
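
After this block y_hat always holds a full-band waveform: the multi-band tensor survives only as y_hat_sub for the new subband loss, while the discriminator and the full-band STFT loss consume the synthesized signal. A shape-level sketch of the round trip, assuming four PQMF bands:

    # Shapes are illustrative for a 4-band PQMF.
    y_hat_sub = y_hat                       # generator output [B, 4, T/4], kept for the subband loss
    y_hat = model_G.pqmf_synthesis(y_hat)   # merge subbands -> [B, 1, T] waveform
    y_G_sub = model_G.pqmf_analysis(y_G)    # split target   -> [B, 4, T/4] subband targets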
         if global_step > c.steps_to_start_discriminator:
             # run D with or without cond. features
             if len(signature(model_D.forward).parameters) == 2:
-                D_out_fake = model_D(in_fake_D, c_G)
+                D_out_fake = model_D(y_hat, c_G)
             else:
-                D_out_fake = model_D(in_fake_D)
+                D_out_fake = model_D(y_hat)
             D_out_real = None
             if c.use_feat_match_loss:
                 with torch.no_grad():
-                    D_out_real = model_D(in_real_D)
+                    D_out_real = model_D(y_G)

             # format D outputs
             if isinstance(D_out_fake, tuple):
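
The hunk ends at the tuple check. Its body, unchanged by this commit, separates discriminators that return (scores, feature maps), as MelGAN's do for feature matching, from plain-score ones; a hedged sketch of that unpacking (the exact branch bodies are not shown in this diff):

    if isinstance(D_out_fake, tuple):
        scores_fake, feats_fake = D_out_fake
        feats_real = None if D_out_real is None else D_out_real[1]
    else:
        scores_fake, feats_fake, feats_real = D_out_fake, None, None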
@@ -162,7 +159,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
             # compute losses
             loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
-                                      feats_real)
+                                      feats_real, y_hat_sub, y_G_sub)
             loss_G = loss_G_dict['G_loss']

             # optimizer generator
@@ -272,12 +269,12 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
                                    model_losses=loss_dict)

             # compute spectrograms
-            figures = plot_results(in_fake_D, in_real_D, ap, global_step,
+            figures = plot_results(y_hat, y_G, ap, global_step,
                                    'train')
             tb_logger.tb_train_figures(global_step, figures)

             # Sample audio
-            sample_voice = in_fake_D[0].squeeze(0).detach().cpu().numpy()
+            sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
             tb_logger.tb_train_audios(global_step,
                                       {'train/audio': sample_voice},
                                       c.audio["sample_rate"])
@@ -320,23 +317,20 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
         # generator pass
         y_hat = model_G(c_G)
-        in_real_D = y_hat
-        in_fake_D = y_G
+        y_hat_sub = None
+        y_G_sub = None

         # PQMF formatting
         if y_hat.shape[1] > 1:
-            in_real_D = y_G
-            in_fake_D = model_G.pqmf_synthesis(y_hat)
-            y_G = model_G.pqmf_analysis(y_G)
-            y_hat = y_hat.view(-1, 1, y_hat.shape[2])
-            y_G = y_G.view(-1, 1, y_G.shape[2])
+            y_hat_sub = y_hat
+            y_hat = model_G.pqmf_synthesis(y_hat)
+            y_G_sub = model_G.pqmf_analysis(y_G)

-        D_out_fake = model_D(in_fake_D)
+        D_out_fake = model_D(y_hat)
         D_out_real = None
         if c.use_feat_match_loss:
             with torch.no_grad():
-                D_out_real = model_D(in_real_D)
+                D_out_real = model_D(y_G)

         # format D outputs
         if isinstance(D_out_fake, tuple):
@@ -350,7 +344,7 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
         # compute losses
         loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
-                                  feats_real)
+                                  feats_real, y_hat_sub, y_G_sub)

         loss_dict = dict()
         for key, value in loss_G_dict.items():
@@ -364,7 +358,7 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
         for key, value in loss_G_dict.items():
             update_eval_values['avg_' + key] = value.item()
         update_eval_values['avg_loader_time'] = loader_time
         update_eval_values['avg_step_time'] = step_time
         keep_avg.update_values(update_eval_values)
# print eval stats
@@ -372,18 +366,19 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
         c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

     # compute spectrograms
-    figures = plot_results(in_fake_D, in_real_D, ap, global_step, 'eval')
+    figures = plot_results(y_hat, y_G, ap, global_step, 'eval')
     tb_logger.tb_eval_figures(global_step, figures)

     # Sample audio
-    sample_voice = in_fake_D[0].squeeze(0).detach().cpu().numpy()
+    sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
     tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
                              c.audio["sample_rate"])

     # synthesize a full voice
     data_loader.return_segments = False

+    tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
     return keep_avg.avg_values
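
The second half of the commit title refers to the tb_logger.tb_eval_stats call added above: without it, the averaged eval values were printed to the console but never reached TensorBoard. A rough sketch of what such a logger method does, assuming a plain SummaryWriter underneath (illustrative, not the repo's implementation):

    from torch.utils.tensorboard import SummaryWriter

    class TbLoggerSketch:
        def __init__(self, log_dir):
            self.writer = SummaryWriter(log_dir)

        def tb_eval_stats(self, step, stats):
            # write each averaged eval value as a scalar under an EvalStats/ scope
            for name, value in stats.items():
                self.writer.add_scalar('EvalStats/' + name, value, step)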