From 59d52a4cd8659fea158aeec925bc0ffa6f505694 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 1 Sep 2021 08:31:33 +0000 Subject: [PATCH] Disable autocast for criterions --- TTS/tts/models/fast_pitch.py | 76 +++++++++++++++++------------------- 1 file changed, 35 insertions(+), 41 deletions(-) diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py index c218535e..b8f346c7 100644 --- a/TTS/tts/models/fast_pitch.py +++ b/TTS/tts/models/fast_pitch.py @@ -1,9 +1,10 @@ from dataclasses import dataclass, field +from typing import Tuple import torch -import torch.nn.functional as F from coqpit import Coqpit from torch import nn +from torch.cuda.amp.autocast_mode import autocast from TTS.tts.layers.feed_forward.decoder import Decoder from TTS.tts.layers.feed_forward.encoder import Encoder @@ -12,7 +13,6 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask -from TTS.tts.utils.measures import alignment_diagonal_score from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor @@ -30,7 +30,7 @@ class AlignmentEncoder(torch.nn.Module): self.softmax = torch.nn.Softmax(dim=3) self.log_softmax = torch.nn.LogSoftmax(dim=3) - self.key_proj = nn.Sequential( + self.key_layer = nn.Sequential( nn.Conv1d( in_key_channels, in_key_channels * 2, @@ -42,7 +42,7 @@ class AlignmentEncoder(torch.nn.Module): nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True), ) - self.query_proj = nn.Sequential( + self.query_layer = nn.Sequential( nn.Conv1d( in_query_channels, in_query_channels * 2, @@ -58,33 +58,26 @@ class AlignmentEncoder(torch.nn.Module): def forward( self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None - ): + ) -> 
Tuple[torch.tensor, torch.tensor]: """Forward pass of the aligner encoder. Shapes: - - queries: :math:`(B, C, T_de)` - - keys: :math:`(B, C_emb, T_en)` - - mask: :math:`(B, T_de)` + - queries: :math:`[B, C, T_de]` + - keys: :math:`[B, C_emb, T_en]` + - mask: :math:`[B, T_de]` Output: - attn (torch.tensor): B x 1 x T1 x T2 attention mask. Final dim T2 should sum to 1. - attn_logprob (torch.tensor): B x 1 x T1 x T2 log-prob attention mask. + attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask. + attn_logp (torch.tensor): :math:`[B, 1, T_en, T_de]` log probabilities. """ - keys_enc = self.key_proj(keys) # B x n_attn_dims x T2 - queries_enc = self.query_proj(queries) - - # Simplistic Gaussian Isotopic Attention - attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2 # B x n_attn_dims x T1 x T2 - attn = -self.temperature * attn.sum(1, keepdim=True) - + key_out = self.key_layer(keys) + query_out = self.query_layer(queries) + attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2 + attn_factor = -self.temperature * attn_factor.sum(1, keepdim=True) if attn_prior is not None: - attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8) - - attn_logprob = attn.clone() - + attn_logp = self.log_softmax(attn_factor) + torch.log(attn_prior[:, None] + 1e-8) if mask is not None: - attn.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf")) - - attn = self.softmax(attn) # softmax along T2 - return attn, attn_logprob + attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf")) + attn = self.softmax(attn_logp) + return attn, attn_logp @dataclass @@ -414,23 +407,24 @@ class FastPitch(BaseTTS): if self.use_aligner: durations = outputs["o_alignment_dur"] - # compute loss - loss_dict = criterion( - outputs["model_outputs"], - mel_input, - mel_lengths, - outputs["durations_log"], - durations, - outputs["pitch"], - outputs["pitch_gt"], - text_lengths, - outputs["alignment_logprob"], - ) + with autocast(enabled=False): # 
use float32 for the criterion + # compute loss + loss_dict = criterion( + outputs["model_outputs"], + mel_input, + mel_lengths, + outputs["durations_log"], + durations, + outputs["pitch"], + outputs["pitch_gt"], + text_lengths, + outputs["alignment_logprob"], + ) - # compute duration error - durations_pred = outputs["durations"] - duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum() - loss_dict["duration_error"] = duration_error + # compute duration error + durations_pred = outputs["durations"] + duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum() + loss_dict["duration_error"] = duration_error return outputs, loss_dict def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use