Disable autocast for criterions

This commit is contained in:
Eren Gölge 2021-09-01 08:31:33 +00:00
parent 98a7271ce8
commit 59d52a4cd8
1 changed file with 35 additions and 41 deletions


@@ -1,9 +1,10 @@
 from dataclasses import dataclass, field
+from typing import Tuple

 import torch
-import torch.nn.functional as F
 from coqpit import Coqpit
 from torch import nn
+from torch.cuda.amp.autocast_mode import autocast

 from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.encoder import Encoder
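
The newly imported autocast is PyTorch's mixed-precision context manager from torch.cuda.amp. For orientation, the usual AMP training pattern it belongs to looks roughly like the sketch below; model, loader, and optimizer are placeholder names for illustration, not objects from this repository.

import torch
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()  # rescales the loss so float16 gradients do not underflow

def training_loop(model, loader, optimizer):  # placeholder objects, illustration only
    for batch, target in loader:
        optimizer.zero_grad()
        with autocast():  # forward pass runs in mixed precision where safe
            output = model(batch)
            loss = torch.nn.functional.mse_loss(output, target)
        scaler.scale(loss).backward()  # backward on the scaled loss
        scaler.step(optimizer)
        scaler.update()

In this codebase the outer autocast region is managed by the trainer; the commit below only nests a disabled region around the criterion.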
@@ -12,7 +13,6 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.data import sequence_mask
-from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
@@ -30,7 +30,7 @@ class AlignmentEncoder(torch.nn.Module):
         self.softmax = torch.nn.Softmax(dim=3)
         self.log_softmax = torch.nn.LogSoftmax(dim=3)
-        self.key_proj = nn.Sequential(
+        self.key_layer = nn.Sequential(
             nn.Conv1d(
                 in_key_channels,
                 in_key_channels * 2,
@@ -42,7 +42,7 @@ class AlignmentEncoder(torch.nn.Module):
             nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True),
         )
-        self.query_proj = nn.Sequential(
+        self.query_layer = nn.Sequential(
             nn.Conv1d(
                 in_query_channels,
                 in_query_channels * 2,
@@ -58,33 +58,26 @@ class AlignmentEncoder(torch.nn.Module):
     def forward(
         self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None
-    ):
+    ) -> Tuple[torch.tensor, torch.tensor]:
         """Forward pass of the aligner encoder.
         Shapes:
-            - queries: :math:`(B, C, T_de)`
-            - keys: :math:`(B, C_emb, T_en)`
-            - mask: :math:`(B, T_de)`
+            - queries: :math:`[B, C, T_de]`
+            - keys: :math:`[B, C_emb, T_en]`
+            - mask: :math:`[B, T_de]`
         Output:
-            attn (torch.tensor): B x 1 x T1 x T2 attention mask. Final dim T2 should sum to 1.
-            attn_logprob (torch.tensor): B x 1 x T1 x T2 log-prob attention mask.
+            attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask.
+            attn_logp (torch.tensor): :math:`[B, 1, T_en, T_de]` log probabilities.
         """
-        keys_enc = self.key_proj(keys)  # B x n_attn_dims x T2
-        queries_enc = self.query_proj(queries)
-        # Simplistic Gaussian Isotopic Attention
-        attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2  # B x n_attn_dims x T1 x T2
-        attn = -self.temperature * attn.sum(1, keepdim=True)
+        key_out = self.key_layer(keys)
+        query_out = self.query_layer(queries)
+        attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2
+        attn_factor = -self.temperature * attn_factor.sum(1, keepdim=True)
         if attn_prior is not None:
-            attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8)
-        attn_logprob = attn.clone()
+            attn_logp = self.log_softmax(attn_factor) + torch.log(attn_prior[:, None] + 1e-8)
         if mask is not None:
-            attn.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))
-        attn = self.softmax(attn)  # softmax along T2
-        return attn, attn_logprob
+            attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))
+        attn = self.softmax(attn_logp)
+        return attn, attn_logp

 @dataclass
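
The rewritten forward pass scores every (decoder frame, text token) pair by the negative squared distance between the projected query and key features, then softmax-normalizes along the text axis; this is the "simplistic Gaussian isotropic attention" the removed comment referred to. A minimal standalone sketch of that scoring step, with made-up tensor sizes and without the masking and prior terms:

import torch

B, C, T_de, T_en = 2, 80, 100, 30   # made-up sizes: batch, channels, decoder frames, text tokens
queries = torch.randn(B, C, T_de)   # projected decoder-frame features (query_out in the diff)
keys = torch.randn(B, C, T_en)      # projected text-token features (key_out in the diff)
temperature = 0.0005                # assumed small positive scalar, standing in for self.temperature

# pairwise squared distance, broadcast to [B, C, T_de, T_en], then summed over channels
dist = (queries[:, :, :, None] - keys[:, :, None]) ** 2
attn_logp = -temperature * dist.sum(1, keepdim=True)  # [B, 1, T_de, T_en]

# softmax over the text axis: each decoder frame gets a distribution over text tokens
attn = torch.softmax(attn_logp, dim=3)
assert torch.allclose(attn.sum(3), torch.ones(B, 1, T_de))

In the model, attn_prior (when given) is folded in via log_softmax plus the log prior, and padded positions are masked to -inf before the final softmax, as shown in the hunk above.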
@@ -414,6 +407,7 @@ class FastPitch(BaseTTS):
         if self.use_aligner:
             durations = outputs["o_alignment_dur"]
+        with autocast(enabled=False):  # use float32 for the criterion
             # compute loss
             loss_dict = criterion(
                 outputs["model_outputs"],
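
This last hunk is the change the commit title refers to: under AMP, tensors inside an active autocast region may be computed in float16, which can destabilize loss terms, so the criterion call is nested in autocast(enabled=False) to keep it in float32. A self-contained sketch of the pattern, with placeholder model, criterion, and data (none of these names come from this file):

import torch
from torch.cuda.amp import autocast

model = torch.nn.Linear(80, 80)   # placeholder model (on GPU it would live on a CUDA device)
criterion = torch.nn.MSELoss()    # placeholder criterion

def train_step(batch, target):
    with autocast():                   # mixed-precision forward pass
        outputs = model(batch)
    with autocast(enabled=False):      # re-enable float32 for the loss
        # tensors produced under autocast may be float16, so cast them up explicitly
        loss = criterion(outputs.float(), target.float())
    return loss

Disabling autocast only around the criterion keeps the speed and memory benefits of the mixed-precision forward pass while computing the loss at full precision.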