mirror of https://github.com/coqui-ai/TTS.git
Fix stopnet training
Commit fc0c4600bd (parent d7a9965389)
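The stopnet was being masked with the wrong lengths during training. The decoder emits one stop token per decoder step, and with reduction factor r each step covers r spectrogram frames, so a stop sequence is only ceil(mel_length / r) steps long, not mel_length. The change below computes per-item stop_target_lengths when formatting the batch, threads them through the Tacotron and Tacotron2 training steps into TacotronLoss.forward, and uses them (instead of output_lens) to mask the stopnet BCE loss. It also reads stopnet_pos_weight from the config rather than a hard-coded default, and pads stop targets with 1 ("stop") instead of 0 ("continue").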
@@ -246,9 +246,9 @@ class Huber(nn.Module):
 class TacotronLoss(torch.nn.Module):
     """Collection of Tacotron set-up based on provided config."""
 
-    def __init__(self, c, stopnet_pos_weight=10, ga_sigma=0.4):
+    def __init__(self, c, ga_sigma=0.4):
         super().__init__()
-        self.stopnet_pos_weight = stopnet_pos_weight
+        self.stopnet_pos_weight = c.stopnet_pos_weight
         self.ga_alpha = c.ga_alpha
         self.decoder_diff_spec_alpha = c.decoder_diff_spec_alpha
         self.postnet_diff_spec_alpha = c.postnet_diff_spec_alpha
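The positive-class weight now comes from the config (c.stopnet_pos_weight) instead of a hard-coded 10. To see what such a weight does, here is a minimal sketch of a masked, positively weighted BCE; this is an illustration, not the repo's exact BCELossMasked implementation:

import torch
import torch.nn.functional as F

def masked_bce(logits, targets, lengths, pos_weight=10.0):
    # Mask is True for valid (unpadded) stop steps: shape [B, T_stop].
    steps = torch.arange(targets.size(1), device=targets.device)
    mask = steps[None, :] < lengths[:, None]
    # pos_weight > 1 up-weights the rare positive ("stop") frames.
    losses = F.binary_cross_entropy_with_logits(
        logits, targets, pos_weight=torch.tensor(pos_weight), reduction="none"
    )
    return (losses * mask).sum() / mask.sum()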
@@ -274,7 +274,7 @@ class TacotronLoss(torch.nn.Module):
             self.criterion_ssim = SSIMLoss()
         # stopnet loss
         # pylint: disable=not-callable
-        self.criterion_st = BCELossMasked(pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None
+        self.criterion_st = BCELossMasked(pos_weight=torch.tensor(self.stopnet_pos_weight)) if c.stopnet else None
 
     def forward(
         self,
|
@@ -284,6 +284,7 @@ class TacotronLoss(torch.nn.Module):
         linear_input,
         stopnet_output,
         stopnet_target,
+        stop_target_length,
         output_lens,
         decoder_b_output,
         alignments,
|
@@ -315,12 +316,12 @@ class TacotronLoss(torch.nn.Module):
         return_dict["decoder_loss"] = decoder_loss
         return_dict["postnet_loss"] = postnet_loss
 
-        # stopnet loss
         stop_loss = (
-            self.criterion_st(stopnet_output, stopnet_target, output_lens) if self.config.stopnet else torch.zeros(1)
+            self.criterion_st(stopnet_output, stopnet_target, stop_target_length)
+            if self.config.stopnet
+            else torch.zeros(1)
         )
-        if not self.config.separate_stopnet and self.config.stopnet:
-            loss += stop_loss
+        loss += stop_loss
         return_dict["stopnet_loss"] = stop_loss
 
         # backward decoder loss (if enabled)
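This hunk is the core fix: the masked loss previously received output_lens, the spectrogram-frame lengths, but the stop sequence has only ceil(T / r) steps. In a padded batch the frame lengths exceed the number of stop steps, so padding was never masked out. A toy illustration with made-up values:

import torch

# Hypothetical batch: reduction factor r = 5, two utterances with 100 and 60
# valid mel frames, padded to 100 frames = 20 stop steps.
r = 5
output_lens = torch.tensor([100, 60])
n_stop_steps = 100 // r  # 20

# Masking stop steps with mel-frame lengths marks everything valid, because
# both 100 and 60 exceed the 20 available steps:
steps = torch.arange(n_stop_steps)
bad_mask = steps[None, :] < output_lens[:, None]
print(bad_mask.sum(1))  # tensor([20, 20]): item 2's padding leaks into the loss

# Masking with ceil(output_lens / r) recovers the true per-item step counts:
stop_target_lengths = torch.divide(output_lens, r).ceil_()
good_mask = steps[None, :] < stop_target_lengths[:, None]
print(good_mask.sum(1))  # tensor([20, 12])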
@@ -119,9 +119,10 @@ class BaseTTS(BaseModel):
                ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
                durations[idx, : text_lengths[idx]] = dur
 
-        # set stop targets view, we predict a single stop token per iteration.
+        # set stop targets wrt reduction factor
         stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
         stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
+        stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_()
 
         return {
             "text_input": text_input,
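The ceil is needed because a partial final group of frames still gets a stop prediction. For example (values chosen for illustration):

import torch

mel_lengths = torch.tensor([100, 87])
r = 7  # hypothetical reduction factor
# 87 frames cover 13 decoder steps (12 full groups of 7 plus a partial one),
# so the lengths are rounded up:
stop_target_lengths = torch.divide(mel_lengths, r).ceil_()
print(stop_target_lengths)  # tensor([15., 13.])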
|
@@ -131,6 +132,7 @@ class BaseTTS(BaseModel):
             "mel_lengths": mel_lengths,
             "linear_input": linear_input,
             "stop_targets": stop_targets,
+            "stop_target_lengths": stop_target_lengths,
             "attn_mask": attn_mask,
             "durations": durations,
             "speaker_ids": speaker_ids,
@@ -219,6 +219,7 @@ class Tacotron(BaseTacotron):
         mel_lengths = batch["mel_lengths"]
         linear_input = batch["linear_input"]
         stop_targets = batch["stop_targets"]
+        stop_target_lengths = batch["stop_target_lengths"]
         speaker_ids = batch["speaker_ids"]
         d_vectors = batch["d_vectors"]
 
|
@@ -250,6 +251,7 @@ class Tacotron(BaseTacotron):
             linear_input,
             outputs["stop_tokens"],
             stop_targets,
+            stop_target_lengths,
             mel_lengths,
             outputs["decoder_outputs_backward"],
             outputs["alignments"],
@@ -224,6 +224,7 @@ class Tacotron2(BaseTacotron):
         mel_lengths = batch["mel_lengths"]
         linear_input = batch["linear_input"]
         stop_targets = batch["stop_targets"]
+        stop_target_lengths = batch["stop_target_lengths"]
         speaker_ids = batch["speaker_ids"]
         d_vectors = batch["d_vectors"]
 
|
@@ -255,6 +256,7 @@ class Tacotron2(BaseTacotron):
             linear_input,
             outputs["stop_tokens"],
             stop_targets,
+            stop_target_lengths,
             mel_lengths,
             outputs["decoder_outputs_backward"],
             outputs["alignments"],
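Tacotron and Tacotron2 receive identical plumbing: each unpacks the new stop_target_lengths batch key and passes it to the criterion immediately after stop_targets, matching the position of the new stop_target_length parameter in TacotronLoss.forward.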
@@ -27,10 +27,19 @@ def prepare_tensor(inputs, out_steps):
     return np.stack([_pad_tensor(x, pad_len) for x in inputs])
 
 
-def _pad_stop_target(x, length):
-    _pad = 0.0
+def _pad_stop_target(x: np.ndarray, length: int, pad_val=1) -> np.ndarray:
+    """Pad stop target array.
+
+    Args:
+        x (np.ndarray): Stop target array.
+        length (int): Length after padding.
+        pad_val (int, optional): Padding value. Defaults to 1.
+
+    Returns:
+        np.ndarray: Padded stop target array.
+    """
     assert x.ndim == 1
-    return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=_pad)
+    return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=pad_val)
 
 
 def prepare_stop_target(inputs, out_steps):
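Padding with 1 ("stop") instead of 0 ("continue") keeps the padded region consistent with the end-of-utterance state. An illustrative call, assuming the function from the hunk above is in scope:

import numpy as np

x = np.array([0.0, 0.0, 1.0])  # the stop flag fires at the last valid step
print(_pad_stop_target(x, 5))  # [0. 0. 1. 1. 1.]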
@@ -207,7 +207,7 @@ class TestTTSDataset(unittest.TestCase):
             assert linear_input[1 - idx, -1].sum() == 0
             assert mel_input[1 - idx, -1].sum() == 0
             assert stop_target[1, mel_lengths[1] - 1] == 1
-            assert stop_target[1, mel_lengths[1] :].sum() == 0
+            assert stop_target[1, mel_lengths[1] :].sum() == stop_target.shape[1] - mel_lengths[1]
             assert len(mel_lengths.shape) == 1
 
             # check batch zero-frame conditions (zero-frame disabled)
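The test follows the new padding convention: since padded steps are now 1, the tail of a stop-target row sums to the number of padded steps rather than 0. A small check with made-up values:

import numpy as np

stop_target = np.array([[0.0, 0.0, 1.0, 1.0, 1.0]])  # valid length 3, padded to 5
valid_len = 3
assert stop_target[0, valid_len:].sum() == stop_target.shape[1] - valid_len  # 2 == 2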