Disable autocast for criterions

Eren Gölge 2021-09-01 08:31:33 +00:00
parent 98a7271ce8
commit 59d52a4cd8
1 changed file with 35 additions and 41 deletions


@@ -1,9 +1,10 @@
 from dataclasses import dataclass, field
 from typing import Tuple
 import torch
 import torch.nn.functional as F
 from coqpit import Coqpit
 from torch import nn
+from torch.cuda.amp.autocast_mode import autocast
 from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.encoder import Encoder
@@ -12,7 +13,6 @@ from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
@@ -30,7 +30,7 @@ class AlignmentEncoder(torch.nn.Module):
         self.softmax = torch.nn.Softmax(dim=3)
         self.log_softmax = torch.nn.LogSoftmax(dim=3)
-        self.key_proj = nn.Sequential(
+        self.key_layer = nn.Sequential(
             nn.Conv1d(
                 in_key_channels,
                 in_key_channels * 2,
@@ -42,7 +42,7 @@ class AlignmentEncoder(torch.nn.Module):
             nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True),
         )
-        self.query_proj = nn.Sequential(
+        self.query_layer = nn.Sequential(
             nn.Conv1d(
                 in_query_channels,
                 in_query_channels * 2,
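
Note: the two hunks above only rename key_proj/query_proj to key_layer/query_layer; the layers themselves are unchanged. For orientation, a hedged reconstruction of one projection stack, where the first conv's kernel size, padding, and the ReLU between the convs are assumptions; only the 1x1 output conv is visible verbatim in the diff:

from torch import nn

attn_channels = 80    # illustrative sizes, not taken from the diff
in_key_channels = 512

key_layer = nn.Sequential(
    # first conv widens the channels; kernel_size/padding are assumed here
    nn.Conv1d(in_key_channels, in_key_channels * 2, kernel_size=3, padding=1, bias=True),
    nn.ReLU(),  # assumed activation between the two convs
    # 1x1 projection down to the attention channels (verbatim in the hunk above)
    nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True),
)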
@@ -58,33 +58,26 @@ class AlignmentEncoder(torch.nn.Module):
     def forward(
         self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None
-    ):
+    ) -> Tuple[torch.tensor, torch.tensor]:
         """Forward pass of the aligner encoder.
         Shapes:
-            - queries: :math:`(B, C, T_de)`
-            - keys: :math:`(B, C_emb, T_en)`
-            - mask: :math:`(B, T_de)`
+            - queries: :math:`[B, C, T_de]`
+            - keys: :math:`[B, C_emb, T_en]`
+            - mask: :math:`[B, T_de]`
         Output:
-            attn (torch.tensor): B x 1 x T1 x T2 attention mask. Final dim T2 should sum to 1.
-            attn_logprob (torch.tensor): B x 1 x T1 x T2 log-prob attention mask.
+            attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask.
+            attn_logp (torch.tensor): :math:`[B, 1, T_en, T_de]` log probabilities.
         """
-        keys_enc = self.key_proj(keys)  # B x n_attn_dims x T2
-        queries_enc = self.query_proj(queries)
-        # Simplistic Gaussian Isotopic Attention
-        attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2  # B x n_attn_dims x T1 x T2
-        attn = -self.temperature * attn.sum(1, keepdim=True)
+        key_out = self.key_layer(keys)
+        query_out = self.query_layer(queries)
+        attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2
+        attn_factor = -self.temperature * attn_factor.sum(1, keepdim=True)
         if attn_prior is not None:
-            attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8)
-        attn_logprob = attn.clone()
+            attn_logp = self.log_softmax(attn_factor) + torch.log(attn_prior[:, None] + 1e-8)
         if mask is not None:
-            attn.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))
-        attn = self.softmax(attn)  # softmax along T2
-        return attn, attn_logprob
+            attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))
+        attn = self.softmax(attn_logp)
+        return attn, attn_logp

 @dataclass
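
For reference, a self-contained sketch of the attention the renamed forward pass computes, using random stand-in tensors. Shapes follow the docstring above; the temperature value, names, and the prior-free/mask-free path are illustrative:

import torch

B, C_attn, T_en, T_de = 2, 80, 50, 120
query_out = torch.randn(B, C_attn, T_de)  # stand-in for query_layer(queries)
key_out = torch.randn(B, C_attn, T_en)    # stand-in for key_layer(keys)
temperature = 0.0005                      # illustrative; the module stores its own value

# Negative scaled L2 distance between every query/key pair, summed over channels.
attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2
attn_factor = -temperature * attn_factor.sum(1, keepdim=True)  # [B, 1, T_de, T_en]

attn_logp = torch.log_softmax(attn_factor, dim=3)  # log probabilities (no prior, no mask)
attn = attn_logp.exp()                             # equals softmax(attn_factor) along the last dim
assert torch.allclose(attn.sum(3), torch.ones(B, 1, T_de))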
@@ -414,23 +407,24 @@ class FastPitch(BaseTTS):
         if self.use_aligner:
             durations = outputs["o_alignment_dur"]
-        # compute loss
-        loss_dict = criterion(
-            outputs["model_outputs"],
-            mel_input,
-            mel_lengths,
-            outputs["durations_log"],
-            durations,
-            outputs["pitch"],
-            outputs["pitch_gt"],
-            text_lengths,
-            outputs["alignment_logprob"],
-        )
+        with autocast(enabled=False):  # use float32 for the criterion
+            # compute loss
+            loss_dict = criterion(
+                outputs["model_outputs"],
+                mel_input,
+                mel_lengths,
+                outputs["durations_log"],
+                durations,
+                outputs["pitch"],
+                outputs["pitch_gt"],
+                text_lengths,
+                outputs["alignment_logprob"],
+            )
-        # compute duration error
-        durations_pred = outputs["durations"]
-        duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum()
-        loss_dict["duration_error"] = duration_error
+            # compute duration error
+            durations_pred = outputs["durations"]
+            duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum()
+            loss_dict["duration_error"] = duration_error
         return outputs, loss_dict

     def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):  # pylint: disable=no-self-use
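
The last hunk is the commit's point: the forward pass keeps running under mixed precision, but the criterion is evaluated with autocast disabled, since loss terms such as the log-prob alignment loss can overflow or lose precision in float16. A standalone sketch of the pattern, with a toy model and loss as stand-ins for the FastPitch model and criterion (assumes a CUDA device):

import torch
from torch.cuda.amp.autocast_mode import autocast

model = torch.nn.Linear(80, 80).cuda()  # stand-in for the TTS model
criterion = torch.nn.L1Loss()           # stand-in for the FastPitch criterion
x = torch.randn(8, 80, device="cuda")
y = torch.randn(8, 80, device="cuda")

with autocast(enabled=True):   # forward pass may run in float16
    y_hat = model(x)
with autocast(enabled=False):  # use float32 for the criterion
    # tensors produced under autocast may be float16; cast them back explicitly
    loss = criterion(y_hat.float(), y.float())
loss.backward()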