diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py
index a17e1b2b..9ed5dc91 100644
--- a/TTS/tts/models/tacotron.py
+++ b/TTS/tts/models/tacotron.py
@@ -24,7 +24,7 @@ class Tacotron(BaseTacotron):
             a multi-speaker model. Defaults to None.
     """
 
-    def __init__(self, config: Coqpit, speaker_manager: SpeakerManager=None):
+    def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
         super().__init__(config)
 
         self.speaker_manager = speaker_manager
diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py
index e2ae8532..4307c90e 100644
--- a/TTS/tts/models/tacotron2.py
+++ b/TTS/tts/models/tacotron2.py
@@ -1,5 +1,6 @@
 # coding: utf-8
 
+from typing import Dict
 import torch
 from coqpit import Coqpit
 from torch import nn
@@ -38,7 +39,7 @@ class Tacotron2(BaseTacotron):
            Speaker manager for multi-speaker training. Use only for multi-speaker training. Defaults to None.
     """
 
-    def __init__(self, config: Coqpit, speaker_manager: SpeakerManager=None):
+    def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
         super().__init__(config)
 
         self.speaker_manager = speaker_manager
@@ -132,11 +133,11 @@ class Tacotron2(BaseTacotron):
         """Forward pass for training with Teacher Forcing.
 
         Shapes:
-            text: [B, T_in]
-            text_lengths: [B]
-            mel_specs: [B, T_out, C]
-            mel_lengths: [B]
-            aux_input: 'speaker_ids': [B, 1] and 'd_vectors':[B, C]
+            text: :math:`[B, T_in]`
+            text_lengths: :math:`[B]`
+            mel_specs: :math:`[B, T_out, C]`
+            mel_lengths: :math:`[B]`
+            aux_input: 'speaker_ids': :math:`[B, 1]` and 'd_vectors': :math:`[B, C]`
         """
         aux_input = self._format_aux_input(aux_input)
         outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
@@ -199,9 +200,9 @@ class Tacotron2(BaseTacotron):
     def inference(self, text, aux_input=None):
         """Forward pass for inference with no Teacher-Forcing.
 
-        Shapes:
-            text: :math:`[B, T_in]`
-            text_lengths: :math:`[B]`
+        Shapes:
+            text: :math:`[B, T_in]`
+            text_lengths: :math:`[B]`
         """
         aux_input = self._format_aux_input(aux_input)
         embedded_inputs = self.embedding(text).transpose(1, 2)
@@ -236,12 +237,12 @@ class Tacotron2(BaseTacotron):
         }
         return outputs
 
-    def train_step(self, batch, criterion):
+    def train_step(self, batch: Dict, criterion: torch.nn.Module):
         """A single training step. Forward pass and loss computation.
 
         Args:
-            batch ([type]): [description]
-            criterion ([type]): [description]
+            batch (Dict): A dictionary of input tensors.
+            criterion (torch.nn.Module): Callable criterion to compute the model loss.
""" text_input = batch["text_input"] text_lengths = batch["text_lengths"] @@ -296,6 +297,7 @@ class Tacotron2(BaseTacotron): return outputs, loss_dict def _create_logs(self, batch, outputs, ap): + """Create dashboard log information.""" postnet_outputs = outputs["model_outputs"] alignments = outputs["alignments"] alignments_backward = outputs["alignments_backward"] @@ -321,6 +323,7 @@ class Tacotron2(BaseTacotron): def train_log( self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int ) -> None: # pylint: disable=no-self-use + """Log training progress.""" ap = assets["audio_processor"] figures, audios = self._create_logs(batch, outputs, ap) logger.train_figures(steps, figures) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 7561780f..3b7df353 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1,5 +1,4 @@ import math -import os import random from dataclasses import dataclass, field from itertools import chain diff --git a/TTS/tts/utils/ssim.py b/TTS/tts/utils/ssim.py index caed575f..883efdb8 100644 --- a/TTS/tts/utils/ssim.py +++ b/TTS/tts/utils/ssim.py @@ -23,8 +23,10 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True): mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) - mu1_sq = mu1.pow(2) - mu2_sq = mu2.pow(2) + # TODO: check if you need AMP disabled + # with torch.cuda.amp.autocast(enabled=False): + mu1_sq = mu1.float().pow(2) + mu2_sq = mu2.float().pow(2) mu1_mu2 = mu1 * mu2 sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq