Update glowtts docstrings and docs

This commit is contained in:
Eren Gölge 2021-06-30 14:30:55 +02:00
parent 21126839a8
commit 2e1a428b83
26 changed files with 305 additions and 225 deletions

View File

@ -2,7 +2,7 @@
Welcome to the 🐸TTS! Welcome to the 🐸TTS!
This repository is governed by the Contributor Covenant Code of Conduct - [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md). This repository is governed by [the Contributor Covenant Code of Conduct](https://github.com/coqui-ai/TTS/blob/main/CODE_OF_CONDUCT.md).
## Where to start. ## Where to start.
We welcome everyone who likes to contribute to 🐸TTS. We welcome everyone who likes to contribute to 🐸TTS.

View File

@ -3,7 +3,7 @@
🐸TTS is a library for advanced Text-to-Speech generation. It's built on the latest research and was designed to achieve the best trade-off among ease of training, speed and quality. 🐸TTS is a library for advanced Text-to-Speech generation. It's built on the latest research and was designed to achieve the best trade-off among ease of training, speed and quality.
🐸TTS comes with [pretrained models](https://github.com/coqui-ai/TTS/wiki/Released-Models) and tools for measuring dataset quality, and is already used in **20+ languages** for products and research projects. 🐸TTS comes with [pretrained models](https://github.com/coqui-ai/TTS/wiki/Released-Models) and tools for measuring dataset quality, and is already used in **20+ languages** for products and research projects.
[![CircleCI](https://github.com/coqui-ai/TTS/actions/workflows/main.yml/badge.svg)]() [![GithubActions](https://github.com/coqui-ai/TTS/actions/workflows/main.yml/badge.svg)](https://github.com/coqui-ai/TTS/actions)
[![License](<https://img.shields.io/badge/License-MPL%202.0-brightgreen.svg>)](https://opensource.org/licenses/MPL-2.0) [![License](<https://img.shields.io/badge/License-MPL%202.0-brightgreen.svg>)](https://opensource.org/licenses/MPL-2.0)
[![Docs](<https://readthedocs.org/projects/tts/badge/?version=latest&style=plastic>)](https://tts.readthedocs.io/en/latest/) [![Docs](<https://readthedocs.org/projects/tts/badge/?version=latest&style=plastic>)](https://tts.readthedocs.io/en/latest/)
[![PyPI version](https://badge.fury.io/py/TTS.svg)](https://badge.fury.io/py/TTS) [![PyPI version](https://badge.fury.io/py/TTS.svg)](https://badge.fury.io/py/TTS)

View File

@ -985,7 +985,7 @@ def get_last_checkpoint(path):
def process_args(args, config=None): def process_args(args, config=None):
"""Process parsed comand line arguments. """Process parsed comand line arguments and initialize the config if not provided.
Args: Args:
args (argparse.Namespace or dict like): Parsed input arguments. args (argparse.Namespace or dict like): Parsed input arguments.

View File

@ -7,7 +7,7 @@ from TTS.tts.configs.shared_configs import BaseTTSConfig
class GlowTTSConfig(BaseTTSConfig): class GlowTTSConfig(BaseTTSConfig):
"""Defines parameters for GlowTTS model. """Defines parameters for GlowTTS model.
Example: Example:
>>> from TTS.tts.configs import GlowTTSConfig >>> from TTS.tts.configs import GlowTTSConfig
>>> config = GlowTTSConfig() >>> config = GlowTTSConfig()

View File

@ -12,7 +12,8 @@ def squeeze(x, x_mask=None, num_sqz=2):
Note: Note:
each 's' is an n-dimensional vector. each 's' is an n-dimensional vector.
[s1,s2,s3,s4,s5,s6] --> [[s1, s3, s5], [s2, s4, s6]]""" ``[s1,s2,s3,s4,s5,s6] --> [[s1, s3, s5], [s2, s4, s6]]``
"""
b, c, t = x.size() b, c, t = x.size()
t = (t // num_sqz) * num_sqz t = (t // num_sqz) * num_sqz
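As a quick illustration of the interleaving described in the Note above, here is a minimal standalone sketch of the squeeze step in plain PyTorch. It mirrors the reshaping strategy rather than the library function itself:

```python
import torch

def squeeze_sketch(x, num_sqz=2):
    """[B, C, T] -> [B, C * num_sqz, T // num_sqz] by interleaving time steps."""
    b, c, t = x.size()
    t = (t // num_sqz) * num_sqz                  # drop trailing frames that do not fit
    x = x[:, :, :t]
    x = x.view(b, c, t // num_sqz, num_sqz)
    return x.permute(0, 3, 1, 2).contiguous().view(b, c * num_sqz, t // num_sqz)

x = torch.arange(6.0).view(1, 1, 6)               # [s1, s2, s3, s4, s5, s6]
print(squeeze_sketch(x))                          # [[s1, s3, s5], [s2, s4, s6]]
```

Unsqueezing simply reverses these reshapes, which is why the pair of operations is exactly invertible.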
@ -32,7 +33,8 @@ def unsqueeze(x, x_mask=None, num_sqz=2):
Note: Note:
each 's' is an n-dimensional vector. each 's' is an n-dimensional vector.
[[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5], [s2, s4, s6]]""" ``[[s1, s3, s5], [s2, s4, s6]] --> [s1, s2, s3, s4, s5, s6]``
"""
b, c, t = x.size() b, c, t = x.size()
x_unsqz = x.view(b, num_sqz, c // num_sqz, t) x_unsqz = x.view(b, num_sqz, c // num_sqz, t)
@ -47,7 +49,10 @@ def unsqueeze(x, x_mask=None, num_sqz=2):
class Decoder(nn.Module): class Decoder(nn.Module):
"""Stack of Glow Decoder Modules. """Stack of Glow Decoder Modules.
Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze
::
Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze
Args: Args:
in_channels (int): channels of input tensor. in_channels (int): channels of input tensor.
@ -106,6 +111,12 @@ class Decoder(nn.Module):
) )
def forward(self, x, x_mask, g=None, reverse=False): def forward(self, x, x_mask, g=None, reverse=False):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1 ,T]`
- g: :math:`[B, C]`
"""
if not reverse: if not reverse:
flows = self.flows flows = self.flows
logdet_tot = 0 logdet_tot = 0

View File

@ -6,13 +6,16 @@ from ..generic.normalization import LayerNorm
class DurationPredictor(nn.Module): class DurationPredictor(nn.Module):
"""Glow-TTS duration prediction model. """Glow-TTS duration prediction model.
[2 x (conv1d_kxk -> relu -> layer_norm -> dropout)] -> conv1d_1x1 -> durs
Args: ::
in_channels ([type]): [description]
hidden_channels ([type]): [description] [2 x (conv1d_kxk -> relu -> layer_norm -> dropout)] -> conv1d_1x1 -> durs
kernel_size ([type]): [description]
dropout_p ([type]): [description] Args:
in_channels (int): Number of channels of the input tensor.
hidden_channels (int): Number of hidden channels of the network.
kernel_size (int): Kernel size for the conv layers.
dropout_p (float): Dropout rate used after each conv layer.
""" """
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p): def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p):
@ -34,11 +37,8 @@ class DurationPredictor(nn.Module):
def forward(self, x, x_mask): def forward(self, x, x_mask):
""" """
Shapes: Shapes:
x: [B, C, T] - x: :math:`[B, C, T]`
x_mask: [B, 1, T] - x_mask: :math:`[B, 1, T]`
Returns:
[type]: [description]
""" """
x = self.conv_1(x * x_mask) x = self.conv_1(x * x_mask)
x = torch.relu(x) x = torch.relu(x)
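A rough usage sketch for the duration predictor follows; the module path and the layer sizes are assumptions made only for illustration, not values prescribed by this diff:

```python
import torch
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor  # path assumed

# hypothetical sizes chosen only for the example
dp = DurationPredictor(in_channels=192, hidden_channels=256, kernel_size=3, dropout_p=0.1)
x = torch.randn(2, 192, 50)       # [B, C, T] encoder outputs
x_mask = torch.ones(2, 1, 50)     # [B, 1, T] padding mask
log_durations = dp(x, x_mask)     # [B, 1, T] predicted log durations per input token
print(log_durations.shape)
```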

View File

@ -15,13 +15,16 @@ from TTS.tts.utils.data import sequence_mask
class Encoder(nn.Module): class Encoder(nn.Module):
"""Glow-TTS encoder module. """Glow-TTS encoder module.
embedding -> <prenet> -> encoder_module -> <postnet> --> proj_mean ::
|
|-> proj_var embedding -> <prenet> -> encoder_module -> <postnet> --> proj_mean
| |
|-> concat -> duration_predictor |-> proj_var
|
speaker_embed |-> concat -> duration_predictor
speaker_embed
Args: Args:
num_chars (int): number of characters. num_chars (int): number of characters.
out_channels (int): number of output channels. out_channels (int): number of output channels.
@ -36,7 +39,8 @@ class Encoder(nn.Module):
Shapes: Shapes:
- input: (B, T, C) - input: (B, T, C)
Notes: ::
suggested encoder params... suggested encoder params...
for encoder_type == 'rel_pos_transformer' for encoder_type == 'rel_pos_transformer'
@ -139,9 +143,9 @@ class Encoder(nn.Module):
def forward(self, x, x_lengths, g=None): def forward(self, x, x_lengths, g=None):
""" """
Shapes: Shapes:
x: [B, C, T] - x: :math:`[B, C, T]`
x_lengths: [B] - x_lengths: :math:`[B]`
g (optional): [B, 1, T] - g (optional): :math:`[B, 1, T]`
""" """
# embedding layer # embedding layer
# [B ,T, D] # [B ,T, D]
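The encoder derives its `[B, 1, T]` mask from `x_lengths` with `sequence_mask`. A generic sketch of that conversion (a reimplementation for illustration, not necessarily identical to `TTS.tts.utils.data.sequence_mask`):

```python
import torch

def sequence_mask_sketch(lengths, max_len=None):
    """lengths: [B] -> boolean mask [B, T], True for valid (non-padded) positions."""
    max_len = max_len or int(lengths.max())
    ids = torch.arange(max_len, device=lengths.device)
    return ids.unsqueeze(0) < lengths.unsqueeze(1)

x_lengths = torch.tensor([5, 3])
x_mask = sequence_mask_sketch(x_lengths).unsqueeze(1).float()  # [B, 1, T]
print(x_mask)  # second sample is masked out after its 3rd frame
```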

View File

@ -10,21 +10,24 @@ from ..generic.normalization import LayerNorm
class ResidualConv1dLayerNormBlock(nn.Module): class ResidualConv1dLayerNormBlock(nn.Module):
"""Conv1d with Layer Normalization and residual connection as in GlowTTS paper.
https://arxiv.org/pdf/1811.00002.pdf
::
x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
|---------------> conv1d_1x1 -----------------------|
Args:
in_channels (int): number of input tensor channels.
hidden_channels (int): number of inner layer channels.
out_channels (int): number of output tensor channels.
kernel_size (int): kernel size of conv1d filter.
num_layers (int): number of blocks.
dropout_p (float): dropout rate for each block.
"""
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, num_layers, dropout_p): def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, num_layers, dropout_p):
"""Conv1d with Layer Normalization and residual connection as in GlowTTS paper.
https://arxiv.org/pdf/1811.00002.pdf
x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
|---------------> conv1d_1x1 -----------------------|
Args:
in_channels (int): number of input tensor channels.
hidden_channels (int): number of inner layer channels.
out_channels (int): number of output tensor channels.
kernel_size (int): kernel size of conv1d filter.
num_layers (int): number of blocks.
dropout_p (float): dropout rate for each block.
"""
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
@ -51,6 +54,11 @@ class ResidualConv1dLayerNormBlock(nn.Module):
self.proj.bias.data.zero_() self.proj.bias.data.zero_()
def forward(self, x, x_mask): def forward(self, x, x_mask):
"""
Shapes:
- x: :math:`[B, C, T]`
- x_mask: :math:`[B, 1, T]`
"""
x_res = x x_res = x
for i in range(self.num_layers): for i in range(self.num_layers):
x = self.conv_layers[i](x * x_mask) x = self.conv_layers[i](x * x_mask)
@ -95,8 +103,8 @@ class InvConvNear(nn.Module):
def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument
""" """
Shapes: Shapes:
x: B x C x T - x: :math:`[B, C, T]`
x_mask: B x 1 x T - x_mask: :math:`[B, 1, T]`
""" """
b, c, t = x.size() b, c, t = x.size()
@ -139,10 +147,12 @@ class CouplingBlock(nn.Module):
"""Glow Affine Coupling block as in GlowTTS paper. """Glow Affine Coupling block as in GlowTTS paper.
https://arxiv.org/pdf/1811.00002.pdf https://arxiv.org/pdf/1811.00002.pdf
x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o ::
'-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^
Args: x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o
'-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^
Args:
in_channels (int): number of input tensor channels. in_channels (int): number of input tensor channels.
hidden_channels (int): number of hidden channels. hidden_channels (int): number of hidden channels.
kernel_size (int): WaveNet filter kernel size. kernel_size (int): WaveNet filter kernel size.
@ -152,8 +162,8 @@ class CouplingBlock(nn.Module):
dropout_p (int): wavenet dropout rate. dropout_p (int): wavenet dropout rate.
sigmoid_scale (bool): enable/disable sigmoid scaling for output scale. sigmoid_scale (bool): enable/disable sigmoid scaling for output scale.
Note: Note:
It does not use conditional inputs differently from WaveGlow. It does not use the conditional inputs differently from WaveGlow.
""" """
def __init__( def __init__(
@ -193,9 +203,9 @@ class CouplingBlock(nn.Module):
def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): # pylint: disable=unused-argument def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): # pylint: disable=unused-argument
""" """
Shapes: Shapes:
x: B x C x T - x: :math:`[B, C, T]`
x_mask: B x 1 x T - x_mask: :math:`[B, 1, T]`
g: B x C x 1 - g: :math:`[B, C, 1]`
""" """
if x_mask is None: if x_mask is None:
x_mask = 1 x_mask = 1
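To make the invertibility of the coupling step concrete, here is a toy affine coupling in plain PyTorch; a single 1x1 convolution stands in for the conv1d -> WaveNet -> conv1d stack, so this is a conceptual sketch rather than the library block:

```python
import torch

def coupling_forward(x, net):
    # split channels, transform the second half conditioned on the first
    x0, x1 = x.chunk(2, dim=1)
    t, log_s = net(x0).chunk(2, dim=1)
    y1 = torch.exp(log_s) * x1 + t
    return torch.cat([x0, y1], dim=1), log_s.sum()      # output and log-determinant contribution

def coupling_reverse(y, net):
    y0, y1 = y.chunk(2, dim=1)
    t, log_s = net(y0).chunk(2, dim=1)
    x1 = (y1 - t) * torch.exp(-log_s)
    return torch.cat([y0, x1], dim=1)

net = torch.nn.Conv1d(2, 4, kernel_size=1)               # stand-in for the WaveNet stack
x = torch.randn(1, 4, 8)
y, _ = coupling_forward(x, net)
assert torch.allclose(coupling_reverse(y, net), x, atol=1e-6)
```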

View File

@ -17,16 +17,18 @@ class RelativePositionMultiHeadAttention(nn.Module):
Note: Note:
Example with relative attention window size 2 Example with relative attention window size 2
input = [a, b, c, d, e]
rel_attn_embeddings = [e(t-2), e(t-1), e(t+1), e(t+2)] - input = [a, b, c, d, e]
- rel_attn_embeddings = [e(t-2), e(t-1), e(t+1), e(t+2)]
So it learns 4 embedding vectors (in total 8) separately for key and value vectors. So it learns 4 embedding vectors (in total 8) separately for key and value vectors.
Considering the input c Considering the input c
e(t-2) corresponds to c -> a
e(t-2) corresponds to c -> b - e(t-2) corresponds to c -> a
e(t-2) corresponds to c -> d - e(t-1) corresponds to c -> b
e(t-2) corresponds to c -> e - e(t+1) corresponds to c -> d
- e(t+2) corresponds to c -> e
These embeddings are shared among different time steps. So inputs a, b, d and e also use These embeddings are shared among different time steps. So inputs a, b, d and e also use
the same embeddings. the same embeddings.
@ -106,6 +108,12 @@ class RelativePositionMultiHeadAttention(nn.Module):
nn.init.xavier_uniform_(self.conv_v.weight) nn.init.xavier_uniform_(self.conv_v.weight)
def forward(self, x, c, attn_mask=None): def forward(self, x, c, attn_mask=None):
"""
Shapes:
- x: :math:`[B, C, T]`
- c: :math:`[B, C, T]`
- attn_mask: :math:`[B, 1, T, T]`
"""
q = self.conv_q(x) q = self.conv_q(x)
k = self.conv_k(c) k = self.conv_k(c)
v = self.conv_v(c) v = self.conv_v(c)
@ -163,9 +171,9 @@ class RelativePositionMultiHeadAttention(nn.Module):
re (Tensor): relative value embedding vector. (a_(i,j)^V) re (Tensor): relative value embedding vector. (a_(i,j)^V)
Shapes: Shapes:
p_attn: [B, H, T, V] - p_attn: :math:`[B, H, T, V]`
re: [H or 1, V, D] - re: :math:`[H or 1, V, D]`
logits: [B, H, T, D] - logits: :math:`[B, H, T, D]`
""" """
logits = torch.matmul(p_attn, re.unsqueeze(0)) logits = torch.matmul(p_attn, re.unsqueeze(0))
return logits return logits
@ -178,9 +186,9 @@ class RelativePositionMultiHeadAttention(nn.Module):
re (Tensor): relative key embedding vector. (a_(i,j)^K) re (Tensor): relative key embedding vector. (a_(i,j)^K)
Shapes: Shapes:
query: [B, H, T, D] - query: :math:`[B, H, T, D]`
re: [H or 1, V, D] - re: :math:`[H or 1, V, D]`
logits: [B, H, T, V] - logits: :math:`[B, H, T, V]`
""" """
# logits = torch.einsum('bhld, kmd -> bhlm', [query, re.to(query.dtype)]) # logits = torch.einsum('bhld, kmd -> bhlm', [query, re.to(query.dtype)])
logits = torch.matmul(query, re.unsqueeze(0).transpose(-2, -1)) logits = torch.matmul(query, re.unsqueeze(0).transpose(-2, -1))
@ -202,10 +210,10 @@ class RelativePositionMultiHeadAttention(nn.Module):
@staticmethod @staticmethod
def _relative_position_to_absolute_position(x): def _relative_position_to_absolute_position(x):
"""Converts tensor from relative to absolute indexing for local attention. """Converts tensor from relative to absolute indexing for local attention.
Args: Shapes:
x: [B, D, length, 2 * length - 1] - x: :math:`[B, H, T, 2 * T - 1]`
Returns: Returns:
A Tensor of shape [B, D, length, length] A Tensor of shape :math:`[B, H, T, T]`
""" """
batch, heads, length, _ = x.size() batch, heads, length, _ = x.size()
# Pad to shift from relative to absolute indexing. # Pad to shift from relative to absolute indexing.
@ -220,8 +228,9 @@ class RelativePositionMultiHeadAttention(nn.Module):
@staticmethod @staticmethod
def _absolute_position_to_relative_position(x): def _absolute_position_to_relative_position(x):
""" """
x: [B, H, T, T] Shapes:
ret: [B, H, T, 2*T-1] - x: :math:`[B, H, T, T]`
- ret: :math:`[B, H, T, 2*T-1]`
""" """
batch, heads, length, _ = x.size() batch, heads, length, _ = x.size()
# padd along column # padd along column
@ -239,7 +248,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
Args: Args:
length (int): an integer scalar. length (int): an integer scalar.
Returns: Returns:
a Tensor with shape [1, 1, length, length] a Tensor with shape :math:`[1, 1, T, T]`
""" """
# L # L
r = torch.arange(length, dtype=torch.float32) r = torch.arange(length, dtype=torch.float32)
@ -362,8 +371,8 @@ class RelativePositionTransformer(nn.Module):
def forward(self, x, x_mask): def forward(self, x, x_mask):
""" """
Shapes: Shapes:
x: [B, C, T] - x: :math:`[B, C, T]`
x_mask: [B, 1, T] - x_mask: :math:`[B, 1, T]`
""" """
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
for i in range(self.num_layers): for i in range(self.num_layers):
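The attention mask construction shown above (`x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)`) simply broadcasts the padding mask into a `[B, 1, T, T]` matrix; a tiny check:

```python
import torch

x_mask = torch.tensor([[[1., 1., 1., 0.]]])              # [B, 1, T], last frame is padding
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)   # [B, 1, T, T]
print(attn_mask[0, 0])
# tensor([[1., 1., 1., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 0.],
#         [0., 0., 0., 0.]])
```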

View File

@ -30,24 +30,31 @@ class GlowTTS(BaseTTS):
the autoregressive model, Tacotron 2, at synthesis with comparable speech quality. We further show that our the autoregressive model, Tacotron 2, at synthesis with comparable speech quality. We further show that our
model can be easily extended to a multi-speaker setting. model can be easily extended to a multi-speaker setting.
Check `GlowTTSConfig` for class arguments. Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments.
Examples:
>>> from TTS.tts.configs import GlowTTSConfig
>>> from TTS.tts.models.glow_tts import GlowTTS
>>> config = GlowTTSConfig()
>>> model = GlowTTS(config)
""" """
def __init__(self, config: GlowTTSConfig): def __init__(self, config: GlowTTSConfig):
super().__init__() super().__init__()
chars, self.config = self.get_characters(config)
self.num_chars = len(chars)
self.decoder_output_dim = config.out_channels
self.init_multispeaker(config)
# pass all config fields to `self` # pass all config fields to `self`
# for fewer code change # for fewer code change
self.config = config self.config = config
for key in config: for key in config:
setattr(self, key, config[key]) setattr(self, key, config[key])
chars, self.config = self.get_characters(config)
self.num_chars = len(chars)
self.decoder_output_dim = config.out_channels
self.init_multispeaker(config)
# if is a multispeaker and c_in_channels is 0, set to 256 # if is a multispeaker and c_in_channels is 0, set to 256
self.c_in_channels = 0 self.c_in_channels = 0
if self.num_speakers > 1: if self.num_speakers > 1:
@ -91,7 +98,7 @@ class GlowTTS(BaseTTS):
@staticmethod @staticmethod
def compute_outputs(attn, o_mean, o_log_scale, x_mask): def compute_outputs(attn, o_mean, o_log_scale, x_mask):
# compute final values with the computed alignment """Compute and format the model outputs with the given alignment map."""
y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose( y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
1, 2 1, 2
) # [b, t', t], [b, t, d] -> [b, d, t'] ) # [b, t', t], [b, t, d] -> [b, d, t']
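A toy example of the matmul in `compute_outputs`, expanding encoder-side means to decoder length with a hard alignment map (all values are made up for illustration):

```python
import torch

attn = torch.tensor([[[[1., 1., 0.],     # token 1 covers the first two decoder frames
                       [0., 0., 1.]]]])  # token 2 covers the last frame; [B, 1, T_en, T_de]
o_mean = torch.tensor([[[10., 20.]]])    # [B, C, T_en]
y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(1, 2)
print(y_mean)                            # tensor([[[10., 10., 20.]]]) -> [B, C, T_de]
```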
@ -107,11 +114,11 @@ class GlowTTS(BaseTTS):
): # pylint: disable=dangerous-default-value ): # pylint: disable=dangerous-default-value
""" """
Shapes: Shapes:
x: [B, T] - x: :math:`[B, T]`
x_lengths: B - x_lengths: :math:`B`
y: [B, T, C] - y: :math:`[B, T, C]`
y_lengths: B - y_lengths: :math:`B`
g: [B, C] or B - g: :math:`[B, C]` or :math:`B`
""" """
y = y.transpose(1, 2) y = y.transpose(1, 2)
y_max_length = y.size(2) y_max_length = y.size(2)
@ -161,12 +168,13 @@ class GlowTTS(BaseTTS):
""" """
It's similar to the teacher forcing in Tacotron. It's similar to the teacher forcing in Tacotron.
It was proposed in: https://arxiv.org/abs/2104.05557 It was proposed in: https://arxiv.org/abs/2104.05557
Shapes: Shapes:
x: [B, T] - x: :math:`[B, T]`
x_lengths: B - x_lengths: :math:`B`
y: [B, T, C] - y: :math:`[B, T, C]`
y_lengths: B - y_lengths: :math:`B`
g: [B, C] or B - g: :math:`[B, C] or B`
""" """
y = y.transpose(1, 2) y = y.transpose(1, 2)
y_max_length = y.size(2) y_max_length = y.size(2)
@ -221,9 +229,9 @@ class GlowTTS(BaseTTS):
): # pylint: disable=dangerous-default-value ): # pylint: disable=dangerous-default-value
""" """
Shapes: Shapes:
y: [B, T, C] - y: :math:`[B, T, C]`
y_lengths: B - y_lengths: :math:`B`
g: [B, C] or B - g: :math:`[B, C] or B`
""" """
y = y.transpose(1, 2) y = y.transpose(1, 2)
y_max_length = y.size(2) y_max_length = y.size(2)

View File

@ -54,7 +54,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
Tensor: spectrogram frames. Tensor: spectrogram frames.
Shapes: Shapes:
x: [B x T] or [B x 1 x T] x: :math:`[B, T]` or :math:`[B, 1, T]`
""" """
if x.ndim == 2: if x.ndim == 2:
x = x.unsqueeze(1) x = x.unsqueeze(1)

View File

@ -22,6 +22,9 @@ class GAN(BaseVocoder):
"""Wrap a generator and a discriminator network. It provides a compatible interface for the trainer. """Wrap a generator and a discriminator network. It provides a compatible interface for the trainer.
It also makes it easy to mix and match different generator and discriminator networks. It also makes it easy to mix and match different generator and discriminator networks.
To implement a new GAN model, you only need to define the generator and the discriminator networks; the rest
is handled by the `GAN` class.
Args: Args:
config (Coqpit): Model configuration. config (Coqpit): Model configuration.
@ -39,12 +42,41 @@ class GAN(BaseVocoder):
self.y_hat_g = None # the last generator prediction to be passed onto the discriminator self.y_hat_g = None # the last generator prediction to be passed onto the discriminator
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Run the generator's forward pass.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: output of the GAN generator network.
"""
return self.model_g.forward(x) return self.model_g.forward(x)
def inference(self, x: torch.Tensor) -> torch.Tensor: def inference(self, x: torch.Tensor) -> torch.Tensor:
"""Run the generator's inference pass.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: output of the GAN generator network.
"""
return self.model_g.inference(x) return self.model_g.inference(x)
def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]: def train_step(self, batch: Dict, criterion: Dict, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Compute model outputs and the loss values. `optimizer_idx` selects the generator or the discriminator for
network on the current pass.
Args:
batch (Dict): Batch of samples returned by the dataloader.
criterion (Dict): Criterion used to compute the losses.
optimizer_idx (int): ID of the optimizer in use on the current pass.
Raises:
ValueError: `optimizer_idx` is an unexpected value.
Returns:
Tuple[Dict, Dict]: model outputs and the computed loss values.
"""
outputs = None outputs = None
loss_dict = None loss_dict = None
@ -145,7 +177,18 @@ class GAN(BaseVocoder):
return outputs, loss_dict return outputs, loss_dict
@staticmethod @staticmethod
def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: def _log(name: str, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, Dict]:
"""Logging shared by the training and evaluation.
Args:
name (str): Name of the run, either `train` or `eval`.
ap (AudioProcessor): Audio processor used in training.
batch (Dict): Batch used in the last train/eval step.
outputs (Dict): Model outputs from the last train/eval step.
Returns:
Tuple[Dict, Dict]: log figures and audio samples.
"""
y_hat = outputs[0]["model_outputs"] y_hat = outputs[0]["model_outputs"]
y = batch["waveform"] y = batch["waveform"]
figures = plot_results(y_hat, y, ap, name) figures = plot_results(y_hat, y, ap, name)
@ -154,13 +197,16 @@ class GAN(BaseVocoder):
return figures, audios return figures, audios
def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
"""Call `_log()` for training."""
return self._log("train", ap, batch, outputs) return self._log("train", ap, batch, outputs)
@torch.no_grad() @torch.no_grad()
def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]: def eval_step(self, batch: Dict, criterion: nn.Module, optimizer_idx: int) -> Tuple[Dict, Dict]:
"""Call `train_step()` with `no_grad()`"""
return self.train_step(batch, criterion, optimizer_idx) return self.train_step(batch, criterion, optimizer_idx)
def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]: def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
"""Call `_log()` for evaluation."""
return self._log("eval", ap, batch, outputs) return self._log("eval", ap, batch, outputs)
def load_checkpoint( def load_checkpoint(
@ -169,6 +215,13 @@ class GAN(BaseVocoder):
checkpoint_path: str, checkpoint_path: str,
eval: bool = False, # pylint: disable=unused-argument, redefined-builtin eval: bool = False, # pylint: disable=unused-argument, redefined-builtin
) -> None: ) -> None:
"""Load a GAN checkpoint and initialize model parameters.
Args:
config (Coqpit): Model config.
checkpoint_path (str): Checkpoint file path.
eval (bool, optional): If true, load the model for inference. Defaults to False.
"""
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
# band-aid for older than v0.0.15 GAN models # band-aid for older than v0.0.15 GAN models
if "model_disc" in state: if "model_disc" in state:
@ -181,9 +234,21 @@ class GAN(BaseVocoder):
self.model_g.remove_weight_norm() self.model_g.remove_weight_norm()
def on_train_step_start(self, trainer) -> None: def on_train_step_start(self, trainer) -> None:
"""Enable the discriminator training based on `steps_to_start_discriminator`
Args:
trainer (Trainer): Trainer object.
"""
self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator self.train_disc = trainer.total_steps_done >= self.config.steps_to_start_discriminator
def get_optimizer(self): def get_optimizer(self) -> List:
"""Initiate and return the GAN optimizers based on the config parameters.
It returns 2 optimizers in a list. The first one is for the generator and the second one is for the discriminator.
Returns:
List: optimizers.
"""
optimizer1 = get_optimizer( optimizer1 = get_optimizer(
self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, self.model_g self.config.optimizer, self.config.optimizer_params, self.config.lr_gen, self.model_g
) )
@ -192,16 +257,37 @@ class GAN(BaseVocoder):
) )
return [optimizer1, optimizer2] return [optimizer1, optimizer2]
def get_lr(self): def get_lr(self) -> List:
"""Set the initial learning rates for each optimizer.
Returns:
List: learning rates for each optimizer.
"""
return [self.config.lr_gen, self.config.lr_disc] return [self.config.lr_gen, self.config.lr_disc]
def get_scheduler(self, optimizer): def get_scheduler(self, optimizer) -> List:
"""Set the schedulers for each optimizer.
Args:
optimizer (List[`torch.optim.Optimizer`]): List of optimizers.
Returns:
List: Schedulers, one for each optimizer.
"""
scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0]) scheduler1 = get_scheduler(self.config.lr_scheduler_gen, self.config.lr_scheduler_gen_params, optimizer[0])
scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1]) scheduler2 = get_scheduler(self.config.lr_scheduler_disc, self.config.lr_scheduler_disc_params, optimizer[1])
return [scheduler1, scheduler2] return [scheduler1, scheduler2]
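Taken together, these hooks implement the usual alternating two-optimizer GAN loop. A self-contained toy of that pattern, where dummy linear modules and an LSGAN-style loss stand in for the real generator, discriminator and criteria:

```python
import torch

model_g, model_d = torch.nn.Linear(4, 4), torch.nn.Linear(4, 1)
optimizers = [torch.optim.Adam(model_g.parameters(), lr=1e-4),   # optimizer_idx == 0 -> generator
              torch.optim.Adam(model_d.parameters(), lr=2e-4)]   # optimizer_idx == 1 -> discriminator

x = torch.randn(8, 4)                                            # stands in for a real batch
for optimizer_idx, opt in enumerate(optimizers):
    if optimizer_idx == 0:                                       # generator pass
        loss = (1 - model_d(model_g(x))).pow(2).mean()
    else:                                                        # discriminator pass
        loss = model_d(model_g(x).detach()).pow(2).mean() + (1 - model_d(x)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
```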
@staticmethod @staticmethod
def format_batch(batch): def format_batch(batch: List) -> Dict:
"""Format the batch for training.
Args:
batch (List): Batch out of the dataloader.
Returns:
Dict: formatted model inputs.
"""
if isinstance(batch[0], list): if isinstance(batch[0], list):
x_G, y_G = batch[0] x_G, y_G = batch[0]
x_D, y_D = batch[1] x_D, y_D = batch[1]
@ -218,6 +304,19 @@ class GAN(BaseVocoder):
verbose: bool, verbose: bool,
num_gpus: int, num_gpus: int,
): ):
"""Initiate and return the GAN dataloader.
Args:
config (Coqpit): Model config.
ap (AudioProcessor): Audio processor.
is_eval (bool): Set up the dataloader for evaluation if true.
data_items (List): Data samples.
verbose (bool): Log information if true.
num_gpus (int): Number of GPUs in use.
Returns:
DataLoader: Torch dataloader.
"""
dataset = GANDataset( dataset = GANDataset(
ap=ap, ap=ap,
items=data_items, items=data_items,

View File

@ -34,7 +34,7 @@ class PQMF(tf.keras.layers.Layer):
def analysis(self, x): def analysis(self, x):
""" """
x : B x 1 x T x : :math:`[B, 1, T]`
""" """
x = tf.transpose(x, perm=[0, 2, 1]) x = tf.transpose(x, perm=[0, 2, 1])
x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]], constant_values=0.0) x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]], constant_values=0.0)

View File

@ -92,7 +92,7 @@ class MelganGenerator(tf.keras.models.Model):
@tf.function(experimental_relax_shapes=True) @tf.function(experimental_relax_shapes=True)
def call(self, c, training=False): def call(self, c, training=False):
""" """
c : B x C x T c : :math:`[B, C, T]`
""" """
if training: if training:
raise NotImplementedError() raise NotImplementedError()

View File

@ -113,7 +113,7 @@ def sample_from_discretized_mix_logistic(y, log_scale_min=None):
""" """
Sample from discretized mixture of logistic distributions Sample from discretized mixture of logistic distributions
Args: Args:
y (Tensor): B x C x T y (Tensor): :math:`[B, C, T]`
log_scale_min (float): Log scale minimum value log_scale_min (float): Log scale minimum value
Returns: Returns:
Tensor: sample in range of [-1, 1]. Tensor: sample in range of [-1, 1].
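The sampling itself boils down to inverse-CDF sampling from the selected logistic component; a toy single-component sketch (the library additionally picks a mixture component from the predicted weights and clamps the result):

```python
import torch

mu, log_scale = torch.tensor(0.0), torch.tensor(-2.0)   # made-up component parameters
u = torch.rand(())                                       # uniform noise in (0, 1)
x = mu + torch.exp(log_scale) * (torch.log(u) - torch.log(1.0 - u))
x = x.clamp(-1.0, 1.0)                                   # keep the sample in [-1, 1]
```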

View File

@ -1,25 +0,0 @@
# AudioProcessor
`TTS.utils.audio.AudioProcessor` is the core class for all the audio processing routines. It provides an API for
- Feature extraction.
- Sound normalization.
- Reading and writing audio files.
- Sampling audio signals.
- Normalizing and denormalizing audio signals.
- Griffin-Lim vocoder.
The `AudioProcessor` needs to be initialized with `TTS.config.shared_configs.BaseAudioConfig`. Any model config
must also inherit from or instantiate `BaseAudioConfig`.
## AudioProcessor
```{eval-rst}
.. autoclass:: TTS.utils.audio.AudioProcessor
:members:
```
## BaseAudioConfig
```{eval-rst}
.. autoclass:: TTS.config.shared_configs.BaseAudioConfig
:members:
```
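For orientation, a rough sketch of the pattern this page describes; the exact keyword arguments beyond `sample_rate` are assumptions, not part of this diff:

```python
from TTS.config.shared_configs import BaseAudioConfig
from TTS.utils.audio import AudioProcessor

audio_config = BaseAudioConfig(sample_rate=22050)        # any model config embeds or extends this
ap = AudioProcessor(**audio_config.to_dict())            # init the processor from the config
wav = ap.load_wav("example.wav")                         # read audio at the configured rate
mel = ap.melspectrogram(wav)                             # feature extraction
```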

View File

@ -50,6 +50,43 @@ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'TODO/*']
source_suffix = [".rst", ".md"] source_suffix = [".rst", ".md"]
# extensions
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx.ext.autosectionlabel',
'myst_parser',
"sphinx_copybutton",
"sphinx_inline_tabs",
]
# 'sphinxcontrib.katex',
# 'sphinx.ext.autosectionlabel',
# autosectionlabel throws warnings if section names are duplicated.
# The following tells autosectionlabel to not throw a warning for
# duplicated section names that are in different documents.
autosectionlabel_prefix_document = True
language = None
autodoc_inherit_docstrings = False
# Disable displaying type annotations, these can be very verbose
autodoc_typehints = 'none'
# Enable overriding of function signatures in the first line of the docstring.
autodoc_docstring_signature = True
napoleon_custom_sections = [('Shapes', 'shape')]
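For illustration, a docstring written in the style these settings target (the napoleon custom `Shapes` section with the `:math:` role); the class below is hypothetical:

```python
import torch

class MyLayer(torch.nn.Module):
    def forward(self, x, x_mask):
        """Apply the padding mask to the input.

        Shapes:
            - x: :math:`[B, C, T]`
            - x_mask: :math:`[B, 1, T]`
        """
        return x * x_mask
```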
# -- Options for HTML output ------------------------------------------------- # -- Options for HTML output -------------------------------------------------
@ -80,23 +117,3 @@ html_sidebars = {
# relative to this directory. They are copied after the builtin static files, # relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css". # so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static'] html_static_path = ['_static']
# using markdown
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx.ext.autosectionlabel',
'myst_parser',
"sphinx_copybutton",
"sphinx_inline_tabs",
]
# 'sphinxcontrib.katex',
# 'sphinx.ext.autosectionlabel',

View File

@ -1,4 +1,4 @@
# Converting Torch Tacotron to TF 2 # Converting Torch to TF 2
Currently, 🐸TTS supports the vanilla Tacotron2 and MelGAN models in TF 2. It does not support advanced attention methods and other small tricks used by the Torch models. You can convert any Torch model trained after v0.0.2. Currently, 🐸TTS supports the vanilla Tacotron2 and MelGAN models in TF 2. It does not support advanced attention methods and other small tricks used by the Torch models. You can convert any Torch model trained after v0.0.2.

View File

@ -1,25 +0,0 @@
# Datasets
## TTS Dataset
```{eval-rst}
.. autoclass:: TTS.tts.datasets.TTSDataset
:members:
```
## Vocoder Dataset
```{eval-rst}
.. autoclass:: TTS.vocoder.datasets.gan_dataset.GANDataset
:members:
```
```{eval-rst}
.. autoclass:: TTS.vocoder.datasets.wavegrad_dataset.WaveGradDataset
:members:
```
```{eval-rst}
.. autoclass:: TTS.vocoder.datasets.wavernn_dataset.WaveRNNDataset
:members:
```

View File

@ -105,7 +105,7 @@ The best approach is to pick a set of promising models and run a Mean-Opinion-Sc
- Check the 4th step under "How can I check model performance?" - Check the 4th step under "How can I check model performance?"
## How can I test a trained model? ## How can I test a trained model?
- The best way is to use `tts` or `tts-server` commands. For details check {ref}`here <Synthesizing Speech>`. - The best way is to use `tts` or `tts-server` commands. For details check {ref}`here <synthesizing_speech>`.
- If you need to code your own ```TTS.utils.synthesizer.Synthesizer``` class. - If you need to code your own ```TTS.utils.synthesizer.Synthesizer``` class.
## My Tacotron model does not stop - I see "Decoder stopped with 'max_decoder_steps" - Stopnet does not work. ## My Tacotron model does not stop - I see "Decoder stopped with 'max_decoder_steps" - Stopnet does not work.

View File

@ -36,7 +36,7 @@
There is also the `callback` interface by which you can manipulate both the model and the `Trainer` states. Callbacks give you There is also the `callback` interface by which you can manipulate both the model and the `Trainer` states. Callbacks give you
the infinite flexibility to add custom behaviours for your model and training routines. the infinite flexibility to add custom behaviours for your model and training routines.
For more details, see {ref}`BaseTTS <Base TTS Model>` and `TTS/utils/callbacks.py`. For more details, see {ref}`BaseTTS <Base TTS Model>` and :obj:`TTS.utils.callbacks`.
6. Optionally, define `MyModelArgs`. 6. Optionally, define `MyModelArgs`.

View File

@ -2,7 +2,6 @@
```{include} ../../README.md ```{include} ../../README.md
:relative-images: :relative-images:
``` ```
---- ----
# Documentation Content # Documentation Content
@ -27,14 +26,28 @@
formatting_your_dataset formatting_your_dataset
what_makes_a_good_dataset what_makes_a_good_dataset
tts_datasets tts_datasets
converting_torch_to_tf
.. toctree:: .. toctree::
:maxdepth: 2 :maxdepth: 2
:caption: Main Classes :caption: Main Classes
trainer_api main_classes/trainer_api
audio_processor main_classes/audio_processor
model_api main_classes/model_api
configuration main_classes/dataset
dataset main_classes/gan
```
.. toctree::
:maxdepth: 2
:caption: `tts` Models
models/glow_tts.md
.. toctree::
:maxdepth: 2
:caption: `vocoder` Models
main_classes/gan
```

View File

@ -1,4 +1,4 @@
# AudioProcessor # AudioProcessor API
`TTS.utils.audio.AudioProcessor` is the core class for all the audio processing routines. It provides an API for `TTS.utils.audio.AudioProcessor` is the core class for all the audio processing routines. It provides an API for

View File

@ -19,6 +19,6 @@ Model API provides you a set of functions that easily make your model compatible
## Base `vocoder` Model ## Base `vocoder` Model
```{eval-rst} ```{eval-rst}
.. autoclass:: TTS.tts.models.base_vocoder.BaseVocoder` .. autoclass:: TTS.vocoder.models.base_vocoder.BaseVocoder
:members: :members:
``` ```

View File

@ -1,24 +0,0 @@
# Model API
Model API provides you a set of functions that easily make your model compatible with the `Trainer`,
`Synthesizer` and `ModelZoo`.
## Base TTS Model
```{eval-rst}
.. autoclass:: TTS.model.BaseModel
:members:
```
## Base `tts` Model
```{eval-rst}
.. autoclass:: TTS.tts.models.base_tts.BaseTTS
:members:
```
## Base `vocoder` Model
```{eval-rst}
.. autoclass:: TTS.tts.models.base_vocoder.BaseVocoder`
:members:
```

View File

@ -1,17 +0,0 @@
# Trainer API
The {class}`TTS.trainer.Trainer` provides a lightweight, extensible, and feature-complete training run-time. We optimized it for 🐸 but
it can also be used for any DL training in different domains. It supports distributed multi-GPU and mixed-precision (apex or torch.amp) training.
## Trainer
```{eval-rst}
.. autoclass:: TTS.trainer.Trainer
:members:
```
## TrainingArgs
```{eval-rst}
.. autoclass:: TTS.trainer.TrainingArgs
:members:
```