mirror of https://github.com/coqui-ai/TTS.git
Disable autocast for criterions
This commit is contained in:
parent 98a7271ce8
commit 59d52a4cd8
@@ -1,9 +1,10 @@
 from dataclasses import dataclass, field
+from typing import Tuple

 import torch
 import torch.nn.functional as F
 from coqpit import Coqpit
 from torch import nn
+from torch.cuda.amp.autocast_mode import autocast

 from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.encoder import Encoder
@@ -12,7 +13,6 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
-
@@ -30,7 +30,7 @@ class AlignmentEncoder(torch.nn.Module):
         self.softmax = torch.nn.Softmax(dim=3)
         self.log_softmax = torch.nn.LogSoftmax(dim=3)

-        self.key_proj = nn.Sequential(
+        self.key_layer = nn.Sequential(
             nn.Conv1d(
                 in_key_channels,
                 in_key_channels * 2,
@@ -42,7 +42,7 @@ class AlignmentEncoder(torch.nn.Module):
             nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True),
         )

-        self.query_proj = nn.Sequential(
+        self.query_layer = nn.Sequential(
             nn.Conv1d(
                 in_query_channels,
                 in_query_channels * 2,
@@ -58,33 +58,26 @@ class AlignmentEncoder(torch.nn.Module):

     def forward(
         self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None
-    ):
+    ) -> Tuple[torch.tensor, torch.tensor]:
         """Forward pass of the aligner encoder.
         Shapes:
-            - queries: :math:`(B, C, T_de)`
-            - keys: :math:`(B, C_emb, T_en)`
-            - mask: :math:`(B, T_de)`
+            - queries: :math:`[B, C, T_de]`
+            - keys: :math:`[B, C_emb, T_en]`
+            - mask: :math:`[B, T_de]`
         Output:
-            attn (torch.tensor): B x 1 x T1 x T2 attention mask. Final dim T2 should sum to 1.
-            attn_logprob (torch.tensor): B x 1 x T1 x T2 log-prob attention mask.
+            attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask.
+            attn_logp (torch.tensor): :math:`[B, 1, T_en, T_de]` log probabilities.
         """
-        keys_enc = self.key_proj(keys)  # B x n_attn_dims x T2
-        queries_enc = self.query_proj(queries)
-
-        # Simplistic Gaussian Isotopic Attention
-        attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2  # B x n_attn_dims x T1 x T2
-        attn = -self.temperature * attn.sum(1, keepdim=True)
-
+        key_out = self.key_layer(keys)
+        query_out = self.query_layer(queries)
+        attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2
+        attn_factor = -self.temperature * attn_factor.sum(1, keepdim=True)
         if attn_prior is not None:
-            attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8)
-
-        attn_logprob = attn.clone()
-
+            attn_logp = self.log_softmax(attn_factor) + torch.log(attn_prior[:, None] + 1e-8)
         if mask is not None:
-            attn.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))
-
-        attn = self.softmax(attn)  # softmax along T2
-        return attn, attn_logprob
+            attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))
+        attn = self.softmax(attn_logp)
+        return attn, attn_logp


 @dataclass
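For context, the rewritten forward above implements a Gaussian-style soft alignment: squared L2 distances between the projected decoder queries and encoder keys, scaled by a negative temperature, converted to log-probabilities, optionally combined with an alignment prior, masked, and normalized with a softmax. Below is a minimal self-contained sketch of that computation; the function name, the key_mask shape, and the default temperature value are illustrative assumptions, not part of the commit.

import torch


def soft_alignment(query_out, key_out, attn_prior=None, key_mask=None, temperature=0.0005):
    """Toy version of the alignment computed after the projection layers.

    query_out: [B, C_attn, T_de], key_out: [B, C_attn, T_en]
    attn_prior: [B, T_de, T_en] or None; key_mask: [B, T_en] or None (assumed shapes)
    """
    # squared L2 distance between every decoder step and every encoder step
    dist = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2  # [B, C_attn, T_de, T_en]
    scores = -temperature * dist.sum(1, keepdim=True)  # [B, 1, T_de, T_en]
    attn_logp = torch.log_softmax(scores, dim=3)
    if attn_prior is not None:
        # fold the prior in log space; 1e-8 avoids log(0)
        attn_logp = attn_logp + torch.log(attn_prior[:, None] + 1e-8)
    if key_mask is not None:
        # padded encoder steps get zero probability after the softmax
        attn_logp = attn_logp.masked_fill(~key_mask.bool()[:, None, None, :], -float("inf"))
    attn = torch.softmax(attn_logp, dim=3)  # each decoder step sums to 1 over T_en
    return attn, attn_logp

Keeping the log-probabilities alongside the normalized attention matters for training: in the FastPitch train step below, they reach the criterion as outputs["alignment_logprob"].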
@@ -414,23 +407,24 @@ class FastPitch(BaseTTS):
         if self.use_aligner:
             durations = outputs["o_alignment_dur"]

-        # compute loss
-        loss_dict = criterion(
-            outputs["model_outputs"],
-            mel_input,
-            mel_lengths,
-            outputs["durations_log"],
-            durations,
-            outputs["pitch"],
-            outputs["pitch_gt"],
-            text_lengths,
-            outputs["alignment_logprob"],
-        )
+        with autocast(enabled=False):  # use float32 for the criterion
+            # compute loss
+            loss_dict = criterion(
+                outputs["model_outputs"],
+                mel_input,
+                mel_lengths,
+                outputs["durations_log"],
+                durations,
+                outputs["pitch"],
+                outputs["pitch_gt"],
+                text_lengths,
+                outputs["alignment_logprob"],
+            )

-        # compute duration error
-        durations_pred = outputs["durations"]
-        duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum()
-        loss_dict["duration_error"] = duration_error
+            # compute duration error
+            durations_pred = outputs["durations"]
+            duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum()
+            loss_dict["duration_error"] = duration_error
         return outputs, loss_dict

     def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):  # pylint: disable=no-self-use
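The core of this commit is the new with autocast(enabled=False): block in train_step: under mixed-precision training, loss terms that take logs or divide by small values can overflow or lose precision in float16, so the criterion is pinned to float32. A minimal sketch of the pattern follows, assuming a CUDA device; the stand-in model, criterion, and optimizer are illustrative, not the ones from TTS.

import torch
from torch.cuda.amp import GradScaler
from torch.cuda.amp.autocast_mode import autocast  # same import path as the diff

model = torch.nn.Linear(80, 80).cuda()  # stand-in for the acoustic model
criterion = torch.nn.L1Loss()           # stand-in for the FastPitch criterion
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler()

mel = torch.randn(8, 80, device="cuda")
target = torch.randn(8, 80, device="cuda")

optimizer.zero_grad()
with autocast():  # the forward pass may run in float16
    pred = model(mel)

with autocast(enabled=False):  # the loss runs in float32 even under AMP
    # tensors created under autocast keep their float16 dtype, so cast explicitly
    loss = criterion(pred.float(), target.float())

scaler.scale(loss).backward()  # loss scaling keeps float16 gradients from underflowing
scaler.step(optimizer)
scaler.update()

Note that autocast(enabled=False) only changes how operations inside the block are dispatched; it does not convert existing tensors, which is why the explicit .float() casts in the sketch matter.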