mirror of https://github.com/coqui-ai/TTS.git
docstrings for all glow-tts layers
parent d3b7284be4
commit e7fad928e7
@@ -5,6 +5,15 @@ from ..generic.normalization import LayerNorm


class DurationPredictor(nn.Module):
    """Glow-TTS duration prediction model.

    [2 x (conv1d_kxk -> relu -> layer_norm -> dropout)] -> conv1d_1x1 -> durs

    Args:
        in_channels (int): number of input tensor channels.
        hidden_channels (int): number of hidden channels.
        kernel_size (int): kernel size of the convolution layers.
        dropout_p (float): dropout rate used after each conv layer.
    """
    def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p):
        super().__init__()
        # class arguments
@@ -28,6 +37,14 @@ class DurationPredictor(nn.Module):
        self.proj = nn.Conv1d(hidden_channels, 1, 1)

    def forward(self, x, x_mask):
        """
        Shapes:
            x: [B, C, T]
            x_mask: [B, 1, T]

        Returns:
            torch.Tensor: predicted durations with shape [B, 1, T].
        """
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
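As a quick illustration of the shape convention above, a hypothetical forward pass; the channel sizes below are made up for the example, not model defaults:

```python
import torch

# Hypothetical smoke test for the layer above; sizes are illustrative.
dur_predictor = DurationPredictor(in_channels=192,
                                  hidden_channels=256,
                                  kernel_size=3,
                                  dropout_p=0.1)
x = torch.randn(8, 192, 50)      # [B, C, T] encoder outputs
x_mask = torch.ones(8, 1, 50)    # [B, 1, T] binary padding mask
durs = dur_predictor(x, x_mask)  # [B, 1, T] predicted durations
```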
@@ -142,6 +142,12 @@ class Encoder(nn.Module):
                                                 dropout_p_dp)

    def forward(self, x, x_lengths, g=None):
        """
        Shapes:
            x: [B, C, T]
            x_lengths: [B]
            g (optional): [B, 1, T]
        """
        # embedding layer
        # [B, T, D]
        x = self.emb(x) * math.sqrt(self.hidden_channels)
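The mask construction used by `forward` is not shown in this hunk; a minimal sketch of how a `[B, 1, T]` mask can be derived from `x_lengths` (the `sequence_mask` helper here is an assumption for illustration, not the library's own):

```python
import torch

def sequence_mask(lengths, max_len):
    """Build a [B, 1, T] binary mask from per-sample lengths (assumed helper)."""
    seq = torch.arange(max_len, device=lengths.device).unsqueeze(0)  # [1, T]
    return (seq < lengths.unsqueeze(1)).unsqueeze(1).float()         # [B, 1, T]

x_lengths = torch.tensor([50, 32, 41])
x_mask = sequence_mask(x_lengths, max_len=50)  # 1s up to each length, 0s after
```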
@@ -7,8 +7,46 @@ from TTS.tts.layers.glow_tts.glow import LayerNorm


class RelativePositionMultiHeadAttention(nn.Module):
    """Multi-head attention with Relative Positional embedding.
    https://arxiv.org/pdf/1809.04281.pdf

    It learns positional embeddings for a window of neighbours. For keys and values,
    it learns a different set of embeddings. Key embeddings are aggregated with the attention
    scores and value embeddings are aggregated with the output.

    Note:
        Example with relative attention window size 2

            input = [a, b, c, d, e]
            rel_attn_embeddings = [e(t-2), e(t-1), e(t+1), e(t+2)]

        So it learns 4 embedding vectors (in total 8) separately for key and value vectors.

        Considering the input c,

            e(t-2) corresponds to c -> a
            e(t-1) corresponds to c -> b
            e(t+1) corresponds to c -> d
            e(t+2) corresponds to c -> e

        These embeddings are shared among different time steps, so inputs a, b, d and e
        also use the same embeddings.

        Embeddings are ignored when the relative window is out of range for the first
        and the last n items.

    Args:
        channels (int): input and inner layer channels.
        out_channels (int): output channels.
        num_heads (int): number of attention heads.
        rel_attn_window_size (int, optional): relative attention window size.
            If 4, for each time step the next and previous 4 time steps are attended.
            If None, relative encoding is disabled and it is a regular transformer.
            Defaults to None.
        heads_share (bool, optional): share relative positional embeddings among the
            attention heads. Defaults to True.
        dropout_p (float, optional): dropout rate. Defaults to 0.
        input_length (int, optional): input length for positional encoding. Defaults to None.
        proximal_bias (bool, optional): enable/disable proximal bias as in the paper. Defaults to False.
        proximal_init (bool, optional): enable/disable proximal init as in the paper.
            Initializes key and query layer weights the same. Defaults to False.
    """
    def __init__(self,
                 channels,
@@ -20,6 +58,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
                 input_length=None,
                 proximal_bias=False,
                 proximal_init=False):
        super().__init__()
        assert channels % num_heads == 0, " [!] channels should be divisible by num_heads."
        # class attributes
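Assuming a `forward(x, c, attn_mask)` signature in which query and context share the same tensor for self-attention (the forward itself is not in this hunk), an illustrative call might look like this; the channel sizes are made up:

```python
import torch

# Illustrative instantiation; forward signature is an assumption.
attn = RelativePositionMultiHeadAttention(channels=192,
                                          out_channels=192,
                                          num_heads=2,
                                          rel_attn_window_size=4)
x = torch.randn(8, 192, 50)           # [B, C, T]
attn_mask = torch.ones(8, 1, 50, 50)  # [B, 1, T, T]
out = attn(x, x, attn_mask)           # query and context share x for self-attention
```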
@@ -226,20 +265,28 @@ class RelativePositionMultiHeadAttention(nn.Module):


class FFN(nn.Module):
    """Feed Forward Inner layers for Transformer.

    Args:
        in_channels (int): input tensor channels.
        out_channels (int): output tensor channels.
        hidden_channels (int): inner layers hidden channels.
        kernel_size (int): conv1d filter kernel size.
        dropout_p (float, optional): dropout rate. Defaults to 0.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_channels,
                 kernel_size,
                 dropout_p=0.):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dropout_p = dropout_p

        self.conv_1 = nn.Conv1d(in_channels,
                                hidden_channels,
@@ -253,16 +300,29 @@ class FFN(nn.Module):

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.dropout(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask
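The mask is applied before each convolution and once more on the output, so padded time steps cannot leak through the kernel. An illustrative check, assuming the convolutions preserve the time dimension:

```python
import torch

# Illustrative: padded steps stay zero through the FFN; sizes are made up.
ffn = FFN(in_channels=192, out_channels=192, hidden_channels=768, kernel_size=3)
x = torch.randn(2, 192, 50)
x_mask = torch.ones(2, 1, 50)
x_mask[:, :, 40:] = 0.0                  # last 10 steps are padding
y = ffn(x, x_mask)
assert torch.all(y[:, :, 40:] == 0.0)    # final mask multiply zeroes padding
```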

class RelativePositionTransformer(nn.Module):
    """Transformer with Relative Positional Encoding.
    https://arxiv.org/abs/1803.02155

    Args:
        hidden_channels (int): model hidden channels.
        hidden_channels_ffn (int): hidden channels of FeedForwardNetwork.
        num_heads (int): number of attention heads.
        num_layers (int): number of transformer layers.
        kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1.
        dropout_p (float, optional): dropout rate for self-attention and feed-forward inner layers. Defaults to 0.
        rel_attn_window_size (int, optional): relative attention window size.
            If 4, for each time step the next and previous 4 time steps are attended.
            If None, relative encoding is disabled and it is a regular transformer.
            Defaults to None.
        input_length (int, optional): input length to limit position encoding. Defaults to None.
    """
    def __init__(self,
                 hidden_channels,
                 hidden_channels_ffn,
@@ -305,6 +365,11 @@ class RelativePositionTransformer(nn.Module):
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        """
        Shapes:
            x: [B, C, T]
            x_mask: [B, 1, T]
        """
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i in range(self.num_layers):
            x = x * x_mask
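The outer product of the padding mask with itself yields a `[B, 1, T, T]` attention mask that is 1 only where both query and key positions are valid; a quick shape check:

```python
import torch

# x_mask: [B, 1, T] -> attn_mask: [B, 1, T, T]
x_mask = torch.ones(2, 1, 5)
x_mask[:, :, 3:] = 0.0
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
print(attn_mask.shape)   # torch.Size([2, 1, 5, 5])
print(attn_mask[0, 0])   # 1s only where both query and key are valid
```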
@@ -8,6 +8,17 @@ from TTS.utils.io import RenamingUnpickler


def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
    """Load ```TTS.tts.models``` checkpoints.

    Args:
        model (TTS.tts.models): model object to load the weights for.
        checkpoint_path (string): checkpoint file path.
        amp (apex.amp, optional): Apex amp object to load Apex-related state vars. Defaults to None.
        use_cuda (bool, optional): load model to GPU if True. Defaults to False.

    Returns:
        TTS.tts.models: the model with the checkpoint weights restored.
    """
    try:
        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    except ModuleNotFoundError:
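An illustrative call, assuming the function hands back the updated model as in the Returns note above; the model class and checkpoint path are placeholders:

```python
# `MyModel` and the checkpoint path are hypothetical placeholders.
model = MyModel()
model = load_checkpoint(model, 'checkpoint_10000.pth.tar', use_cuda=True)
model.eval()
```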
@@ -26,6 +37,17 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):


def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs):
    """Save ```TTS.tts.models``` states with extra fields.

    Args:
        model (TTS.tts.models.Model): model object to be saved.
        optimizer (torch.optim.optimizers.Optimizer): model optimizer used for training.
        current_step (int): current number of training steps.
        epoch (int): current number of training epochs.
        r (int): model reduction rate for Tacotron models.
        output_path (str): output path to save the model file.
        amp_state_dict (state_dict, optional): Apex amp state dict if Apex is enabled. Defaults to None.
    """
    if hasattr(model, 'module'):
        model_state = model.module.state_dict()
    else:
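The `hasattr(model, 'module')` check unwraps a `torch.nn.DataParallel` wrapper so the saved state dict keys carry no `module.` prefix; a sketch of the same pattern in isolation:

```python
# DataParallel wraps the model, so state dict keys gain a 'module.' prefix
# unless the underlying module is unwrapped first.
def get_state_dict(model):
    if hasattr(model, 'module'):  # nn.DataParallel or similar wrapper
        return model.module.state_dict()
    return model.state_dict()
```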
@@ -45,6 +67,16 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_


def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):
    """Save model checkpoint, intended for saving checkpoints during training.

    Args:
        model (TTS.tts.models.Model): model object to be saved.
        optimizer (torch.optim.optimizers.Optimizer): model optimizer used for training.
        current_step (int): current number of training steps.
        epoch (int): current number of training epochs.
        r (int): model reduction rate for Tacotron models.
        output_folder (str): output folder to save the model file.
    """
    file_name = 'checkpoint_{}.pth.tar'.format(current_step)
    checkpoint_path = os.path.join(output_folder, file_name)
    print(" > CHECKPOINT : {}".format(checkpoint_path))
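Given the `checkpoint_{}.pth.tar` naming above, a step count of 10000 lands the file at `<output_folder>/checkpoint_10000.pth.tar`; an illustrative call (`model` and `optimizer` are assumed to exist):

```python
# Illustrative call inside a training loop; `model` and `optimizer` are assumed.
save_checkpoint(model, optimizer, current_step=10000, epoch=12, r=1,
                output_folder='runs/glow-tts')
# prints: " > CHECKPOINT : runs/glow-tts/checkpoint_10000.pth.tar"
```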
@@ -52,6 +84,23 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **k


def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, **kwargs):
    """Save model checkpoint, intended for saving the best model after each epoch.
    It compares the current model loss with the best loss so far and saves the
    model if the current loss is better.

    Args:
        target_loss (float): current model loss.
        best_loss (float): best loss so far.
        model (TTS.tts.models.Model): model object to be saved.
        optimizer (torch.optim.optimizers.Optimizer): model optimizer used for training.
        current_step (int): current number of training steps.
        epoch (int): current number of training epochs.
        r (int): model reduction rate for Tacotron models.
        output_folder (str): output folder to save the model file.

    Returns:
        float: updated current best loss.
    """
    if target_loss < best_loss:
        file_name = 'best_model.pth.tar'
        checkpoint_path = os.path.join(output_folder, file_name)
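Because the function returns the updated best loss, callers are expected to feed it back in on the next epoch; an illustrative validation loop (`model`, `optimizer`, and `validate` are placeholders):

```python
# Illustrative loop; `model`, `optimizer`, and `validate` are assumed helpers.
best_loss = float('inf')
for epoch in range(100):
    val_loss = validate(model)
    best_loss = save_best_model(val_loss, best_loss, model, optimizer,
                                current_step=epoch * 1000, epoch=epoch, r=1,
                                output_folder='runs/glow-tts')
```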