Disable autocast for criterions

This commit is contained in:
Eren Gölge 2021-09-01 08:31:33 +00:00
parent 98a7271ce8
commit 59d52a4cd8
1 changed file with 35 additions and 41 deletions

View File

@ -1,9 +1,10 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Tuple
import torch import torch
import torch.nn.functional as F
from coqpit import Coqpit from coqpit import Coqpit
from torch import nn from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from TTS.tts.layers.feed_forward.decoder import Decoder from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.encoder import Encoder from TTS.tts.layers.feed_forward.encoder import Encoder
@ -12,7 +13,6 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.models.base_tts import BaseTTS from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor from TTS.utils.audio import AudioProcessor
@ -30,7 +30,7 @@ class AlignmentEncoder(torch.nn.Module):
self.softmax = torch.nn.Softmax(dim=3) self.softmax = torch.nn.Softmax(dim=3)
self.log_softmax = torch.nn.LogSoftmax(dim=3) self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.key_proj = nn.Sequential( self.key_layer = nn.Sequential(
nn.Conv1d( nn.Conv1d(
in_key_channels, in_key_channels,
in_key_channels * 2, in_key_channels * 2,
@ -42,7 +42,7 @@ class AlignmentEncoder(torch.nn.Module):
nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True), nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True),
) )
self.query_proj = nn.Sequential( self.query_layer = nn.Sequential(
nn.Conv1d( nn.Conv1d(
in_query_channels, in_query_channels,
in_query_channels * 2, in_query_channels * 2,
@ -58,33 +58,26 @@ class AlignmentEncoder(torch.nn.Module):
def forward( def forward(
self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None
): ) -> Tuple[torch.tensor, torch.tensor]:
"""Forward pass of the aligner encoder. """Forward pass of the aligner encoder.
Shapes: Shapes:
- queries: :math:`(B, C, T_de)` - queries: :math:`[B, C, T_de]`
- keys: :math:`(B, C_emb, T_en)` - keys: :math:`[B, C_emb, T_en]`
- mask: :math:`(B, T_de)` - mask: :math:`[B, T_de]`
Output: Output:
attn (torch.tensor): B x 1 x T1 x T2 attention mask. Final dim T2 should sum to 1. attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask.
attn_logprob (torch.tensor): B x 1 x T1 x T2 log-prob attention mask. attn_logp (torch.tensor): :math:`[B, 1, T_en, T_de]` log probabilities.
""" """
keys_enc = self.key_proj(keys) # B x n_attn_dims x T2 key_out = self.key_layer(keys)
queries_enc = self.query_proj(queries) query_out = self.query_layer(queries)
attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2
# Simplistic Gaussian Isotropic Attention attn_factor = -self.temperature * attn_factor.sum(1, keepdim=True)
attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2 # B x n_attn_dims x T1 x T2
attn = -self.temperature * attn.sum(1, keepdim=True)
if attn_prior is not None: if attn_prior is not None:
attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8) attn_logp = self.log_softmax(attn_factor) + torch.log(attn_prior[:, None] + 1e-8)
attn_logprob = attn.clone()
if mask is not None: if mask is not None:
attn.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf")) attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))
attn = self.softmax(attn_logp)
attn = self.softmax(attn) # softmax along T2 return attn, attn_logp
return attn, attn_logprob
@dataclass @dataclass
@ -414,23 +407,24 @@ class FastPitch(BaseTTS):
if self.use_aligner: if self.use_aligner:
durations = outputs["o_alignment_dur"] durations = outputs["o_alignment_dur"]
# compute loss with autocast(enabled=False): # use float32 for the criterion
loss_dict = criterion( # compute loss
outputs["model_outputs"], loss_dict = criterion(
mel_input, outputs["model_outputs"],
mel_lengths, mel_input,
outputs["durations_log"], mel_lengths,
durations, outputs["durations_log"],
outputs["pitch"], durations,
outputs["pitch_gt"], outputs["pitch"],
text_lengths, outputs["pitch_gt"],
outputs["alignment_logprob"], text_lengths,
) outputs["alignment_logprob"],
)
# compute duration error # compute duration error
durations_pred = outputs["durations"] durations_pred = outputs["durations"]
duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum() duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum()
loss_dict["duration_error"] = duration_error loss_dict["duration_error"] = duration_error
return outputs, loss_dict return outputs, loss_dict
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use