mirror of https://github.com/coqui-ai/TTS.git
Fix stopnet training
This commit is contained in:
parent
d7a9965389
commit
fc0c4600bd
|
@ -246,9 +246,9 @@ class Huber(nn.Module):
|
|||
class TacotronLoss(torch.nn.Module):
|
||||
"""Collection of Tacotron set-up based on provided config."""
|
||||
|
||||
def __init__(self, c, stopnet_pos_weight=10, ga_sigma=0.4):
|
||||
def __init__(self, c, ga_sigma=0.4):
|
||||
super().__init__()
|
||||
self.stopnet_pos_weight = stopnet_pos_weight
|
||||
self.stopnet_pos_weight = c.stopnet_pos_weight
|
||||
self.ga_alpha = c.ga_alpha
|
||||
self.decoder_diff_spec_alpha = c.decoder_diff_spec_alpha
|
||||
self.postnet_diff_spec_alpha = c.postnet_diff_spec_alpha
|
||||
|
@ -274,7 +274,7 @@ class TacotronLoss(torch.nn.Module):
|
|||
self.criterion_ssim = SSIMLoss()
|
||||
# stopnet loss
|
||||
# pylint: disable=not-callable
|
||||
self.criterion_st = BCELossMasked(pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None
|
||||
self.criterion_st = BCELossMasked(pos_weight=torch.tensor(self.stopnet_pos_weight)) if c.stopnet else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -284,6 +284,7 @@ class TacotronLoss(torch.nn.Module):
|
|||
linear_input,
|
||||
stopnet_output,
|
||||
stopnet_target,
|
||||
stop_target_length,
|
||||
output_lens,
|
||||
decoder_b_output,
|
||||
alignments,
|
||||
|
@ -315,12 +316,12 @@ class TacotronLoss(torch.nn.Module):
|
|||
return_dict["decoder_loss"] = decoder_loss
|
||||
return_dict["postnet_loss"] = postnet_loss
|
||||
|
||||
# stopnet loss
|
||||
stop_loss = (
|
||||
self.criterion_st(stopnet_output, stopnet_target, output_lens) if self.config.stopnet else torch.zeros(1)
|
||||
self.criterion_st(stopnet_output, stopnet_target, stop_target_length)
|
||||
if self.config.stopnet
|
||||
else torch.zeros(1)
|
||||
)
|
||||
if not self.config.separate_stopnet and self.config.stopnet:
|
||||
loss += stop_loss
|
||||
loss += stop_loss
|
||||
return_dict["stopnet_loss"] = stop_loss
|
||||
|
||||
# backward decoder loss (if enabled)
|
||||
|
|
|
@ -119,9 +119,10 @@ class BaseTTS(BaseModel):
|
|||
), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
|
||||
durations[idx, : text_lengths[idx]] = dur
|
||||
|
||||
# set stop targets view, we predict a single stop token per iteration.
|
||||
# set stop targets wrt reduction factor
|
||||
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
|
||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
|
||||
stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_()
|
||||
|
||||
return {
|
||||
"text_input": text_input,
|
||||
|
@ -131,6 +132,7 @@ class BaseTTS(BaseModel):
|
|||
"mel_lengths": mel_lengths,
|
||||
"linear_input": linear_input,
|
||||
"stop_targets": stop_targets,
|
||||
"stop_target_lengths": stop_target_lengths,
|
||||
"attn_mask": attn_mask,
|
||||
"durations": durations,
|
||||
"speaker_ids": speaker_ids,
|
||||
|
|
|
@ -219,6 +219,7 @@ class Tacotron(BaseTacotron):
|
|||
mel_lengths = batch["mel_lengths"]
|
||||
linear_input = batch["linear_input"]
|
||||
stop_targets = batch["stop_targets"]
|
||||
stop_target_lengths = batch["stop_target_lengths"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
d_vectors = batch["d_vectors"]
|
||||
|
||||
|
@ -250,6 +251,7 @@ class Tacotron(BaseTacotron):
|
|||
linear_input,
|
||||
outputs["stop_tokens"],
|
||||
stop_targets,
|
||||
stop_target_lengths,
|
||||
mel_lengths,
|
||||
outputs["decoder_outputs_backward"],
|
||||
outputs["alignments"],
|
||||
|
|
|
@ -224,6 +224,7 @@ class Tacotron2(BaseTacotron):
|
|||
mel_lengths = batch["mel_lengths"]
|
||||
linear_input = batch["linear_input"]
|
||||
stop_targets = batch["stop_targets"]
|
||||
stop_target_lengths = batch["stop_target_lengths"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
d_vectors = batch["d_vectors"]
|
||||
|
||||
|
@ -255,6 +256,7 @@ class Tacotron2(BaseTacotron):
|
|||
linear_input,
|
||||
outputs["stop_tokens"],
|
||||
stop_targets,
|
||||
stop_target_lengths,
|
||||
mel_lengths,
|
||||
outputs["decoder_outputs_backward"],
|
||||
outputs["alignments"],
|
||||
|
|
|
@ -27,10 +27,19 @@ def prepare_tensor(inputs, out_steps):
|
|||
return np.stack([_pad_tensor(x, pad_len) for x in inputs])
|
||||
|
||||
|
||||
def _pad_stop_target(x, length):
|
||||
_pad = 0.0
|
||||
def _pad_stop_target(x: np.ndarray, length: int, pad_val=1) -> np.ndarray:
|
||||
"""Pad stop target array.
|
||||
|
||||
Args:
|
||||
x (np.ndarray): Stop target array.
|
||||
length (int): Length after padding.
|
||||
pad_val (int, optional): Padding value. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Padded stop target array.
|
||||
"""
|
||||
assert x.ndim == 1
|
||||
return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=_pad)
|
||||
return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=pad_val)
|
||||
|
||||
|
||||
def prepare_stop_target(inputs, out_steps):
|
||||
|
|
|
@ -207,7 +207,7 @@ class TestTTSDataset(unittest.TestCase):
|
|||
assert linear_input[1 - idx, -1].sum() == 0
|
||||
assert mel_input[1 - idx, -1].sum() == 0
|
||||
assert stop_target[1, mel_lengths[1] - 1] == 1
|
||||
assert stop_target[1, mel_lengths[1] :].sum() == 0
|
||||
assert stop_target[1, mel_lengths[1] :].sum() == stop_target.shape[1] - mel_lengths[1]
|
||||
assert len(mel_lengths.shape) == 1
|
||||
|
||||
# check batch zero-frame conditions (zero-frame disabled)
|
||||
|
|
Loading…
Reference in New Issue