Fix stopnet training

This commit is contained in:
Eren Gölge 2021-07-20 17:34:42 +02:00
parent d7a9965389
commit fc0c4600bd
6 changed files with 28 additions and 12 deletions

View File

@ -246,9 +246,9 @@ class Huber(nn.Module):
class TacotronLoss(torch.nn.Module):
"""Collection of Tacotron set-up based on provided config."""
def __init__(self, c, stopnet_pos_weight=10, ga_sigma=0.4):
def __init__(self, c, ga_sigma=0.4):
super().__init__()
self.stopnet_pos_weight = stopnet_pos_weight
self.stopnet_pos_weight = c.stopnet_pos_weight
self.ga_alpha = c.ga_alpha
self.decoder_diff_spec_alpha = c.decoder_diff_spec_alpha
self.postnet_diff_spec_alpha = c.postnet_diff_spec_alpha
@ -274,7 +274,7 @@ class TacotronLoss(torch.nn.Module):
self.criterion_ssim = SSIMLoss()
# stopnet loss
# pylint: disable=not-callable
self.criterion_st = BCELossMasked(pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None
self.criterion_st = BCELossMasked(pos_weight=torch.tensor(self.stopnet_pos_weight)) if c.stopnet else None
def forward(
self,
@ -284,6 +284,7 @@ class TacotronLoss(torch.nn.Module):
linear_input,
stopnet_output,
stopnet_target,
stop_target_length,
output_lens,
decoder_b_output,
alignments,
@ -315,12 +316,12 @@ class TacotronLoss(torch.nn.Module):
return_dict["decoder_loss"] = decoder_loss
return_dict["postnet_loss"] = postnet_loss
# stopnet loss
stop_loss = (
self.criterion_st(stopnet_output, stopnet_target, output_lens) if self.config.stopnet else torch.zeros(1)
self.criterion_st(stopnet_output, stopnet_target, stop_target_length)
if self.config.stopnet
else torch.zeros(1)
)
if not self.config.separate_stopnet and self.config.stopnet:
loss += stop_loss
loss += stop_loss
return_dict["stopnet_loss"] = stop_loss
# backward decoder loss (if enabled)

View File

@ -119,9 +119,10 @@ class BaseTTS(BaseModel):
), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
durations[idx, : text_lengths[idx]] = dur
# set stop targets view, we predict a single stop token per iteration.
# set stop targets wrt reduction factor
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_()
return {
"text_input": text_input,
@ -131,6 +132,7 @@ class BaseTTS(BaseModel):
"mel_lengths": mel_lengths,
"linear_input": linear_input,
"stop_targets": stop_targets,
"stop_target_lengths": stop_target_lengths,
"attn_mask": attn_mask,
"durations": durations,
"speaker_ids": speaker_ids,

View File

@ -219,6 +219,7 @@ class Tacotron(BaseTacotron):
mel_lengths = batch["mel_lengths"]
linear_input = batch["linear_input"]
stop_targets = batch["stop_targets"]
stop_target_lengths = batch["stop_target_lengths"]
speaker_ids = batch["speaker_ids"]
d_vectors = batch["d_vectors"]
@ -250,6 +251,7 @@ class Tacotron(BaseTacotron):
linear_input,
outputs["stop_tokens"],
stop_targets,
stop_target_lengths,
mel_lengths,
outputs["decoder_outputs_backward"],
outputs["alignments"],

View File

@ -224,6 +224,7 @@ class Tacotron2(BaseTacotron):
mel_lengths = batch["mel_lengths"]
linear_input = batch["linear_input"]
stop_targets = batch["stop_targets"]
stop_target_lengths = batch["stop_target_lengths"]
speaker_ids = batch["speaker_ids"]
d_vectors = batch["d_vectors"]
@ -255,6 +256,7 @@ class Tacotron2(BaseTacotron):
linear_input,
outputs["stop_tokens"],
stop_targets,
stop_target_lengths,
mel_lengths,
outputs["decoder_outputs_backward"],
outputs["alignments"],

View File

@ -27,10 +27,19 @@ def prepare_tensor(inputs, out_steps):
return np.stack([_pad_tensor(x, pad_len) for x in inputs])
def _pad_stop_target(x, length):
_pad = 0.0
def _pad_stop_target(x: np.ndarray, length: int, pad_val=1) -> np.ndarray:
"""Pad stop target array.
Args:
x (np.ndarray): Stop target array.
length (int): Length after padding.
pad_val (int, optional): Padding value. Defaults to 1.
Returns:
np.ndarray: Padded stop target array.
"""
assert x.ndim == 1
return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=_pad)
return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=pad_val)
def prepare_stop_target(inputs, out_steps):

View File

@ -207,7 +207,7 @@ class TestTTSDataset(unittest.TestCase):
assert linear_input[1 - idx, -1].sum() == 0
assert mel_input[1 - idx, -1].sum() == 0
assert stop_target[1, mel_lengths[1] - 1] == 1
assert stop_target[1, mel_lengths[1] :].sum() == 0
assert stop_target[1, mel_lengths[1] :].sum() == stop_target.shape[1] - mel_lengths[1]
assert len(mel_lengths.shape) == 1
# check batch zero-frame conditions (zero-frame disabled)