doc strings for all the glow-tts layers

erogol 2021-01-06 12:36:52 +01:00
parent d3b7284be4
commit e7fad928e7
4 changed files with 145 additions and 8 deletions


@@ -5,6 +5,15 @@ from ..generic.normalization import LayerNorm
class DurationPredictor(nn.Module):
"""Glow-TTS duration prediction model.
[2 x (conv1d_kxk -> relu -> layer_norm -> dropout)] -> conv1d_1x1 -> durs
Args:
in_channels ([type]): [description]
hidden_channels ([type]): [description]
kernel_size ([type]): [description]
dropout_p ([type]): [description]
"""
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p):
super().__init__()
# class arguments
@@ -28,6 +37,14 @@ class DurationPredictor(nn.Module):
self.proj = nn.Conv1d(hidden_channels, 1, 1)
def forward(self, x, x_mask):
"""
Shapes:
x: [B, C, T]
x_mask: [B, 1, T]
Returns:
[type]: [description]
"""
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
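For readers skimming the diff, the layer stack described in the docstring above corresponds roughly to the following standalone sketch. It is not the repo's exact `DurationPredictor`; in particular, the normalization here is a generic channel-wise stand-in for the repo's `LayerNorm`:

```python
import torch
from torch import nn


class DurationPredictorSketch(nn.Module):
    """[2 x (conv1d_kxk -> relu -> layer_norm -> dropout)] -> conv1d_1x1 -> durs"""

    def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p):
        super().__init__()
        pad = kernel_size // 2  # keeps the time dimension unchanged for odd kernel sizes
        self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=pad)
        self.conv_2 = nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=pad)
        # GroupNorm(1, C) acts as a channel-wise normalization stand-in
        self.norm_1 = nn.GroupNorm(1, hidden_channels)
        self.norm_2 = nn.GroupNorm(1, hidden_channels)
        self.dropout = nn.Dropout(dropout_p)
        self.proj = nn.Conv1d(hidden_channels, 1, 1)

    def forward(self, x, x_mask):
        # x: [B, C, T], x_mask: [B, 1, T]
        x = self.dropout(self.norm_1(torch.relu(self.conv_1(x * x_mask))))
        x = self.dropout(self.norm_2(torch.relu(self.conv_2(x * x_mask))))
        return self.proj(x * x_mask) * x_mask  # [B, 1, T] duration predictions
```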


@@ -142,6 +142,12 @@ class Encoder(nn.Module):
dropout_p_dp)
def forward(self, x, x_lengths, g=None):
"""
Shapes:
x: [B, C, T]
x_lengths: [B]
g (optional): [B, 1, T]
"""
# embedding layer
# [B, T, D]
x = self.emb(x) * math.sqrt(self.hidden_channels)
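For context, the shape bookkeeping described in the docstring (token ids in, `[B, C, T]` features and a `[B, 1, T]` mask out) can be sketched as below; the mask construction is an assumption standing in for the repo's own sequence-mask helper:

```python
import math
import torch

B, T, D = 2, 7, 192                    # batch, max token length, hidden_channels
x = torch.randint(0, 100, (B, T))      # token ids, [B, T]
x_lengths = torch.tensor([7, 5])       # valid lengths per sample, [B]

emb = torch.nn.Embedding(100, D)
h = emb(x) * math.sqrt(D)              # [B, T, D], scaled as in the encoder
h = h.transpose(1, 2)                  # [B, C, T], matching the docstring shapes

# float mask of shape [B, 1, T]: 1 for valid steps, 0 for padding
x_mask = (torch.arange(T)[None, :] < x_lengths[:, None]).unsqueeze(1).float()
```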


@@ -7,8 +7,46 @@ from TTS.tts.layers.glow_tts.glow import LayerNorm
class RelativePositionMultiHeadAttention(nn.Module):
- """Implementation of Relative Position Encoding based on
+ """Multi-head attention with Relative Positional embedding.
https://arxiv.org/pdf/1809.04281.pdf
It learns positional embeddings for a window of neighbours. For keys and values,
it learns a different set of embeddings. Key embeddings are aggregated with the attention
scores and value embeddings are aggregated with the output.
Note:
Example with relative attention window size 2
input = [a, b, c, d, e]
rel_attn_embeddings = [e(t-2), e(t-1), e(t+1), e(t+2)]
So it learns 4 embedding vectors (in total 8) separately for key and value vectors.
Considering the input c
e(t-2) corresponds to c -> a
e(t-1) corresponds to c -> b
e(t+1) corresponds to c -> d
e(t+2) corresponds to c -> e
These embeddings are shared among different time steps, so inputs a, b, d and e also use
the same embeddings.
Embeddings are ignored when the relative window extends beyond the sequence, i.e. for the first
and the last n items.
Args:
channels (int): input and inner layer channels.
out_channels (int): output channels.
num_heads (int): number of attention heads.
rel_attn_window_size (int, optional): relative attention window size.
If 4, the next and previous 4 time steps are attended for each time step.
If None (default), relative encoding is disabled and it is a regular transformer.
Defaults to None.
heads_share (bool, optional): share relative positional embeddings across attention heads. Defaults to True.
dropout_p (float, optional): dropout rate. Defaults to 0.
input_length (int, optional): input length for positional encoding. Defaults to None.
proximal_bias (bool, optional): enable/disable proximal bias as in the paper. Defaults to False.
proximal_init (bool, optional): enable/disable proximal init as in the paper.
Init key and query layer weights the same. Defaults to False.
""" """
def __init__(self, def __init__(self,
channels, channels,
@ -20,6 +58,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
input_length=None, input_length=None,
proximal_bias=False, proximal_bias=False,
proximal_init=False): proximal_init=False):
super().__init__() super().__init__()
assert channels % num_heads == 0, " [!] channels should be divisible by num_heads." assert channels % num_heads == 0, " [!] channels should be divisible by num_heads."
# class attributes # class attributes
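The windowed relative embeddings described in the Note can be illustrated with a toy snippet. This is a simplified sketch of the idea only: it keeps one embedding per clipped offset (including offset 0) and ignores the heads_share and vectorization details of the actual layer:

```python
import torch
from torch import nn

window, T, d = 2, 5, 8                                # window size, sequence length, emb dim

# one learned vector per relative offset in [-window, window]
rel_key_emb = nn.Embedding(2 * window + 1, d)

# rel[t, s] = (s - t) clipped to the window, shifted to a valid embedding index
pos = torch.arange(T)
rel = (pos[None, :] - pos[:, None]).clamp(-window, window) + window   # [T, T]

rel_k = rel_key_emb(rel)                              # [T, T, d] key-side relative embeddings
# attention scores combine the content term with a query-times-rel_k term before softmax;
# a second embedding table plays the same role on the value side.
```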
@@ -226,20 +265,28 @@ class RelativePositionMultiHeadAttention(nn.Module):
class FFN(nn.Module):
"""Feed Forward Inner layers for Transformer.
Args:
in_channels (int): input tensor channels.
out_channels (int): output tensor channels.
hidden_channels (int): inner layers hidden channels.
kernel_size (int): conv1d filter kernel size.
dropout_p (float, optional): dropout rate. Defaults to 0.
"""
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
- dropout_p=0.,
- activation=None):
+ dropout_p=0.):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dropout_p = dropout_p
- self.activation = activation
self.conv_1 = nn.Conv1d(in_channels,
hidden_channels,
@@ -253,16 +300,29 @@ class FFN(nn.Module):
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
- if self.activation == "gelu":
-     x = x * torch.sigmoid(1.702 * x)
- else:
-     x = torch.relu(x)
+ x = torch.relu(x)
x = self.dropout(x)
x = self.conv_2(x * x_mask)
return x * x_mask
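A quick, hypothetical shape check for the feed-forward block above (argument values are arbitrary, and `kernel_size=1` sidesteps any padding details not shown in the excerpt):

```python
import torch

# assumes FFN from the module above is importable in the current scope
ffn = FFN(in_channels=192, out_channels=192, hidden_channels=768,
          kernel_size=1, dropout_p=0.1)

x = torch.randn(2, 192, 50)        # [B, C, T]
x_mask = torch.ones(2, 1, 50)      # [B, 1, T]
y = ffn(x, x_mask)                 # [B, out_channels, T]; masked positions stay zero
print(y.shape)                     # torch.Size([2, 192, 50])
```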
class RelativePositionTransformer(nn.Module):
"""Transformer with Relative Positional Encoding.
https://arxiv.org/abs/1803.02155
Args:
hidden_channels (int): model hidden channels.
hidden_channels_ffn (int): hidden channels of the FeedForwardNetwork.
num_heads (int): number of attention heads.
num_layers (int): number of transformer layers.
kernel_size (int, optional): kernel size of the feed-forward inner layers. Defaults to 1.
dropout_p (float, optional): dropout rate for the self-attention and feed-forward inner layers. Defaults to 0.
rel_attn_window_size (int, optional): relative attention window size.
If 4, the next and previous 4 time steps are attended for each time step.
If None (default), relative encoding is disabled and it is a regular transformer.
Defaults to None.
input_length (int, optional): input length to limit positional encoding. Defaults to None.
"""
def __init__(self,
hidden_channels,
hidden_channels_ffn,
@@ -305,6 +365,11 @@ class RelativePositionTransformer(nn.Module):
self.norm_layers_2.append(LayerNorm(hidden_channels))
def forward(self, x, x_mask):
"""
Shapes:
x: [B, C, T]
x_mask: [B, 1, T]
"""
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
for i in range(self.num_layers):
x = x * x_mask
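The per-layer computation hinted at by the loop above follows the usual residual pattern: self-attention, then the conv feed-forward block, each followed by dropout, a residual connection, and layer norm. A sketch of one layer, with the sub-layer call signatures assumed from the excerpt rather than taken verbatim from the repo:

```python
def transformer_layer(x, x_mask, attn, ffn, norm_1, norm_2, dropout):
    """One layer of the residual pattern; a sketch, not verbatim repo code."""
    attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)   # [B, 1, T, T] pairwise mask
    x = x * x_mask
    y = dropout(attn(x, x, attn_mask))   # relative-position self-attention
    x = norm_1(x + y)                    # residual + layer norm
    y = dropout(ffn(x, x_mask))          # conv feed-forward block
    x = norm_2(x + y)                    # residual + layer norm
    return x * x_mask
```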


@@ -8,6 +8,17 @@ from TTS.utils.io import RenamingUnpickler
def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
"""Load ```TTS.tts.models``` checkpoints.
Args:
model (TTS.tts.models): model object to load the weights for.
checkpoint_path (string): checkpoint file path.
amp (apex.amp, optional): Apex amp abject to load apex related state vars. Defaults to None.
use_cuda (bool, optional): load model to GPU if True. Defaults to False.
Returns:
[type]: [description]
"""
try:
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
except ModuleNotFoundError:
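As a rough mental model, the loading path boils down to something like the sketch below; the "model" key and the in-place `load_state_dict` call are assumptions about the checkpoint layout, not verbatim repo code:

```python
import torch

def load_weights(model, checkpoint_path, use_cuda=False):
    # load onto CPU first, then optionally move to GPU
    state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    model.load_state_dict(state["model"])   # assumed key holding the model weights
    if use_cuda:
        model.cuda()
    return model
```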
@@ -26,6 +37,17 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs):
"""Save ```TTS.tts.models``` states with extra fields.
Args:
model (TTS.tts.models.Model): models object to be saved.
optimizer (torch.optim.optimizers.Optimizer): model optimizer used for training.
current_step (int): current number of training steps.
epoch (int): current number of training epochs.
r (int): model reduction rate for Tacotron models.
output_path (str): output path to save the model file.
amp_state_dict (state_dict, optional): Apex.amp state dict if Apex is enabled. Defaults to None.
"""
if hasattr(model, 'module'):
model_state = model.module.state_dict()
else:
@@ -45,6 +67,16 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_
def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):
"""Save model checkpoint, intended for saving checkpoints at training.
Args:
model (TTS.tts.models.Model): models object to be saved.
optimizer (torch.optim.optimizers.Optimizer): model optimizer used for training.
current_step (int): current number of training steps.
epoch (int): current number of training epochs.
r (int): model reduction rate for Tacotron models.
output_path (str): output path to save the model file.
"""
file_name = 'checkpoint_{}.pth.tar'.format(current_step)
checkpoint_path = os.path.join(output_folder, file_name)
print(" > CHECKPOINT : {}".format(checkpoint_path))
@@ -52,6 +84,23 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **k
def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, **kwargs):
"""Save model checkpoint, intended for saving the best model after each epoch.
It compares the current model loss with the best loss so far and saves the
model if the current loss is better.
Args:
target_loss (float): current model loss.
best_loss (float): best loss so far.
model (TTS.tts.models.Model): models object to be saved.
optimizer (torch.optim.optimizers.Optimizer): model optimizer used for training.
current_step (int): current number of training steps.
epoch (int): current number of training epochs.
r (int): model reduction rate for Tacotron models.
output_path (str): output path to save the model file.
Returns:
float: updated current best loss.
"""
if target_loss < best_loss:
file_name = 'best_model.pth.tar'
checkpoint_path = os.path.join(output_folder, file_name)
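Taken together, a training loop would typically call these helpers roughly as sketched below; `model`, `optimizer`, `evaluate`, and the loop constants are placeholders for this example, not repo code:

```python
best_loss = float("inf")
global_step = 0

for epoch in range(num_epochs):
    for batch in train_loader:
        # ... forward pass, loss, backward pass, optimizer.step() ...
        global_step += 1
        if global_step % save_step == 0:
            # periodic snapshot: checkpoint_<step>.pth.tar in the output folder
            save_checkpoint(model, optimizer, global_step, epoch, r, output_folder)

    eval_loss = evaluate(model)  # placeholder evaluation routine
    # overwrites best_model.pth.tar only when eval_loss improves on best_loss
    best_loss = save_best_model(eval_loss, best_loss, model, optimizer,
                                global_step, epoch, r, output_folder)
```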