mirror of https://github.com/coqui-ai/TTS.git
Comment Tacotron2 model
This commit is contained in:
parent
92b6d98443
commit
3da79a4de4
|
@ -1,5 +1,6 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
import torch
|
import torch
|
||||||
from coqpit import Coqpit
|
from coqpit import Coqpit
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
@ -132,11 +133,11 @@ class Tacotron2(BaseTacotron):
|
||||||
"""Forward pass for training with Teacher Forcing.
|
"""Forward pass for training with Teacher Forcing.
|
||||||
|
|
||||||
Shapes:
|
Shapes:
|
||||||
text: [B, T_in]
|
text: :math:`[B, T_in]`
|
||||||
text_lengths: [B]
|
text_lengths: :math:`[B]`
|
||||||
mel_specs: [B, T_out, C]
|
mel_specs: :math:`[B, T_out, C]`
|
||||||
mel_lengths: [B]
|
mel_lengths: :math:`[B]`
|
||||||
aux_input: 'speaker_ids': [B, 1] and 'd_vectors':[B, C]
|
aux_input: 'speaker_ids': :math:`[B, 1]` and 'd_vectors': :math:`[B, C]`
|
||||||
"""
|
"""
|
||||||
aux_input = self._format_aux_input(aux_input)
|
aux_input = self._format_aux_input(aux_input)
|
||||||
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
|
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
|
||||||
|
@ -236,12 +237,12 @@ class Tacotron2(BaseTacotron):
|
||||||
}
|
}
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def train_step(self, batch, criterion):
|
def train_step(self, batch:Dict, criterion:torch.nn.Module):
|
||||||
"""A single training step. Forward pass and loss computation.
|
"""A single training step. Forward pass and loss computation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch ([type]): [description]
|
batch ([Dict]): A dictionary of input tensors.
|
||||||
criterion ([type]): [description]
|
criterion ([type]): Callable criterion to compute model loss.
|
||||||
"""
|
"""
|
||||||
text_input = batch["text_input"]
|
text_input = batch["text_input"]
|
||||||
text_lengths = batch["text_lengths"]
|
text_lengths = batch["text_lengths"]
|
||||||
|
@ -296,6 +297,7 @@ class Tacotron2(BaseTacotron):
|
||||||
return outputs, loss_dict
|
return outputs, loss_dict
|
||||||
|
|
||||||
def _create_logs(self, batch, outputs, ap):
|
def _create_logs(self, batch, outputs, ap):
|
||||||
|
"""Create dashboard log information."""
|
||||||
postnet_outputs = outputs["model_outputs"]
|
postnet_outputs = outputs["model_outputs"]
|
||||||
alignments = outputs["alignments"]
|
alignments = outputs["alignments"]
|
||||||
alignments_backward = outputs["alignments_backward"]
|
alignments_backward = outputs["alignments_backward"]
|
||||||
|
@ -321,6 +323,7 @@ class Tacotron2(BaseTacotron):
|
||||||
def train_log(
|
def train_log(
|
||||||
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
||||||
) -> None: # pylint: disable=no-self-use
|
) -> None: # pylint: disable=no-self-use
|
||||||
|
"""Log training progress."""
|
||||||
ap = assets["audio_processor"]
|
ap = assets["audio_processor"]
|
||||||
figures, audios = self._create_logs(batch, outputs, ap)
|
figures, audios = self._create_logs(batch, outputs, ap)
|
||||||
logger.train_figures(steps, figures)
|
logger.train_figures(steps, figures)
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import math
|
import math
|
||||||
import os
|
|
||||||
import random
|
import random
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
|
|
@ -23,8 +23,10 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True):
|
||||||
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
|
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
|
||||||
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
|
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
|
||||||
|
|
||||||
mu1_sq = mu1.pow(2)
|
# TODO: check if you need AMP disabled
|
||||||
mu2_sq = mu2.pow(2)
|
# with torch.cuda.amp.autocast(enabled=False):
|
||||||
|
mu1_sq = mu1.float().pow(2)
|
||||||
|
mu2_sq = mu2.float().pow(2)
|
||||||
mu1_mu2 = mu1 * mu2
|
mu1_mu2 = mu1 * mu2
|
||||||
|
|
||||||
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
|
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
|
||||||
|
|
Loading…
Reference in New Issue