mirror of https://github.com/coqui-ai/TTS.git
doc strings for all the glow-tts layers
This commit is contained in:
parent
d3b7284be4
commit
e7fad928e7

@@ -5,6 +5,15 @@ from ..generic.normalization import LayerNorm


 class DurationPredictor(nn.Module):
+    """Glow-TTS duration prediction model.
+    [2 x (conv1d_kxk -> relu -> layer_norm -> dropout)] -> conv1d_1x1 -> durs
+
+    Args:
+        in_channels (int): number of input tensor channels.
+        hidden_channels (int): number of hidden channels.
+        kernel_size (int): kernel size of the convolution layers.
+        dropout_p (float): dropout rate.
+    """
     def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p):
         super().__init__()
         # class arguments

@@ -28,6 +37,14 @@ class DurationPredictor(nn.Module):
         self.proj = nn.Conv1d(hidden_channels, 1, 1)

     def forward(self, x, x_mask):
+        """
+        Shapes:
+            x: [B, C, T]
+            x_mask: [B, 1, T]
+
+        Returns:
+            torch.Tensor: predicted durations of shape [B, 1, T].
+        """
         x = self.conv_1(x * x_mask)
         x = torch.relu(x)
         x = self.norm_1(x)
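The new docstring pins down the module interface; a minimal usage sketch (channel sizes and mask values are illustrative, not taken from the diff):

    import torch

    dp = DurationPredictor(in_channels=192, hidden_channels=256,
                           kernel_size=3, dropout_p=0.1)   # sizes assumed
    x = torch.randn(2, 192, 50)       # [B, C, T] encoder outputs
    x_mask = torch.ones(2, 1, 50)     # [B, 1, T] padding mask
    durs = dp(x, x_mask)              # [B, 1, T] predicted durations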

@@ -142,6 +142,12 @@ class Encoder(nn.Module):
                                          dropout_p_dp)

     def forward(self, x, x_lengths, g=None):
+        """
+        Shapes:
+            x: [B, T]
+            x_lengths: [B]
+            g (optional): [B, 1, T]
+        """
         # embedding layer
         # [B, T, D]
         x = self.emb(x) * math.sqrt(self.hidden_channels)
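The embedding output is scaled by sqrt(hidden_channels), the variance-preserving trick from the original Transformer; a toy reproduction of that single line (vocabulary and channel sizes assumed):

    import math
    import torch
    from torch import nn

    emb = nn.Embedding(num_embeddings=130, embedding_dim=192)  # sizes assumed
    x = torch.randint(0, 130, (2, 50))     # [B, T] token ids
    h = emb(x) * math.sqrt(192)            # [B, T, D] scaled embeddings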

@@ -7,8 +7,46 @@ from TTS.tts.layers.glow_tts.glow import LayerNorm


 class RelativePositionMultiHeadAttention(nn.Module):
-    """Implementation of Relative Position Encoding based on
+    """Multi-head attention with Relative Positional embedding.
     https://arxiv.org/pdf/1809.04281.pdf
+
+    It learns positional embeddings for a window of neighbours. For keys and values,
+    it learns a different set of embeddings. Key embeddings are aggregated with the attention
+    scores and value embeddings are aggregated with the output.
+
+    Note:
+        Example with relative attention window size 2
+        input = [a, b, c, d, e]
+        rel_attn_embeddings = [e(t-2), e(t-1), e(t+1), e(t+2)]
+
+        So it learns 4 embedding vectors (in total 8) separately for key and value vectors.
+
+        Considering the input c
+            e(t-2) corresponds to c -> a
+            e(t-1) corresponds to c -> b
+            e(t+1) corresponds to c -> d
+            e(t+2) corresponds to c -> e
+
+        These embeddings are shared among different time steps. So inputs a, b, d and e also use
+        the same embeddings.
+
+        Embeddings are ignored when the relative window is out of range for the first and the last
+        n items.
+
+    Args:
+        channels (int): input and inner layer channels.
+        out_channels (int): output channels.
+        num_heads (int): number of attention heads.
+        rel_attn_window_size (int, optional): relative attention window size.
+            If 4, for each time step the next and previous 4 time steps are attended.
+            If default, relative encoding is disabled and it is a regular transformer.
+            Defaults to None.
+        heads_share (bool, optional): whether to share relative position embeddings among attention heads. Defaults to True.
+        dropout_p (float, optional): dropout rate. Defaults to 0.
+        input_length (int, optional): input length for positional encoding. Defaults to None.
+        proximal_bias (bool, optional): enable/disable proximal bias as in the paper. Defaults to False.
+        proximal_init (bool, optional): enable/disable proximal init as in the paper.
+            Init key and query layer weights the same. Defaults to False.
+    """
     def __init__(self,
                  channels,
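A toy sketch of the windowed indexing the Note describes, for window size 2; this is illustrative only, not the layer's actual implementation:

    # Each valid pair (t -> t') with 1 <= |t' - t| <= w maps to one of 2*w
    # embedding vectors shared across time steps (separately for keys and values).
    def rel_embedding_index(t, t_prime, w=2):
        d = t_prime - t                        # relative offset, e.g. c -> a is -2
        if d == 0 or abs(d) > w:
            return None                        # self or out-of-window: no embedding
        return d + w if d < 0 else d + w - 1   # [-w..-1] -> [0..w-1], [1..w] -> [w..2w-1]

    # input = [a, b, c, d, e]; query at position 2 (the input c):
    for tp in range(5):
        print(tp, rel_embedding_index(2, tp))  # -> 0, 1, None, 2, 3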

@@ -20,6 +58,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
                  input_length=None,
                  proximal_bias=False,
                  proximal_init=False):
+
         super().__init__()
         assert channels % num_heads == 0, " [!] channels should be divisible by num_heads."
         # class attributes

@@ -226,20 +265,28 @@ class RelativePositionMultiHeadAttention(nn.Module):


 class FFN(nn.Module):
+    """Feed Forward Inner layers for Transformer.
+
+    Args:
+        in_channels (int): input tensor channels.
+        out_channels (int): output tensor channels.
+        hidden_channels (int): inner layers hidden channels.
+        kernel_size (int): conv1d filter kernel size.
+        dropout_p (float, optional): dropout rate. Defaults to 0.
+    """
     def __init__(self,
                  in_channels,
                  out_channels,
                  hidden_channels,
                  kernel_size,
-                 dropout_p=0.,
-                 activation=None):
+                 dropout_p=0.):
+
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.hidden_channels = hidden_channels
         self.kernel_size = kernel_size
         self.dropout_p = dropout_p
-        self.activation = activation

         self.conv_1 = nn.Conv1d(in_channels,
                                 hidden_channels,

@@ -253,16 +300,29 @@ class FFN(nn.Module):

     def forward(self, x, x_mask):
         x = self.conv_1(x * x_mask)
-        if self.activation == "gelu":
-            x = x * torch.sigmoid(1.702 * x)
-        else:
-            x = torch.relu(x)
+        x = torch.relu(x)
         x = self.dropout(x)
         x = self.conv_2(x * x_mask)
         return x * x_mask


 class RelativePositionTransformer(nn.Module):
+    """Transformer with Relative Positional Encoding.
+    https://arxiv.org/abs/1803.02155
+
+    Args:
+        hidden_channels (int): model hidden channels.
+        hidden_channels_ffn (int): hidden channels of FeedForwardNetwork.
+        num_heads (int): number of attention heads.
+        num_layers (int): number of transformer layers.
+        kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1.
+        dropout_p (float, optional): dropout rate for self-attention and feed-forward inner layers. Defaults to 0.
+        rel_attn_window_size (int, optional): relative attention window size.
+            If 4, for each time step the next and previous 4 time steps are attended.
+            If default, relative encoding is disabled and it is a regular transformer.
+            Defaults to None.
+        input_length (int, optional): input length to limit position encoding. Defaults to None.
+    """
     def __init__(self,
                  hidden_channels,
                  hidden_channels_ffn,
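The mask multiplications in FFN.forward keep padded time steps at zero both before each convolution (so padding cannot leak through the kernel window) and on the output; a small sketch with assumed sizes:

    import torch

    ffn = FFN(in_channels=192, out_channels=192,
              hidden_channels=768, kernel_size=1)  # sizes assumed
    x = torch.randn(2, 192, 50)                    # [B, C, T]
    x_mask = torch.ones(2, 1, 50)                  # [B, 1, T]
    x_mask[1, :, 40:] = 0.                         # second item padded after t=40
    y = ffn(x, x_mask)                             # [B, C, T]
    assert torch.all(y[1, :, 40:] == 0)            # padded region stays zeroed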

@@ -305,6 +365,11 @@ class RelativePositionTransformer(nn.Module):
         self.norm_layers_2.append(LayerNorm(hidden_channels))

     def forward(self, x, x_mask):
+        """
+        Shapes:
+            x: [B, C, T]
+            x_mask: [B, 1, T]
+        """
         attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
         for i in range(self.num_layers):
             x = x * x_mask
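A hedged construction sketch built from the Args and Shapes above (keyword names come from the docstring; the values and the output layout are assumptions):

    import torch

    transformer = RelativePositionTransformer(hidden_channels=192,
                                              hidden_channels_ffn=768,
                                              num_heads=2,
                                              num_layers=6,
                                              kernel_size=3,
                                              dropout_p=0.1,
                                              rel_attn_window_size=4)
    x = torch.randn(2, 192, 50)       # [B, C, T]
    x_mask = torch.ones(2, 1, 50)     # [B, 1, T]
    y = transformer(x, x_mask)        # keeps the [B, C, T] layout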

@@ -8,6 +8,17 @@ from TTS.utils.io import RenamingUnpickler


 def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
+    """Load ```TTS.tts.models``` checkpoints.
+
+    Args:
+        model (TTS.tts.models): model object to load the weights for.
+        checkpoint_path (string): checkpoint file path.
+        amp (apex.amp, optional): Apex amp object to load apex related state vars. Defaults to None.
+        use_cuda (bool, optional): load model to GPU if True. Defaults to False.
+
+    Returns:
+        [type]: [description]
+    """
     try:
         state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
     except ModuleNotFoundError:
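A hedged call sketch; `model` stands for any ```TTS.tts.models``` instance and the file name is a placeholder (this hunk does not show the return value, so it is ignored here):

    model = build_model()      # hypothetical model constructor
    load_checkpoint(model, 'best_model.pth.tar', use_cuda=True)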

@@ -26,6 +37,17 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):


 def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs):
+    """Save ```TTS.tts.models``` states with extra fields.
+
+    Args:
+        model (TTS.tts.models.Model): model object to be saved.
+        optimizer (torch.optim.Optimizer): model optimizer used for training.
+        current_step (int): current number of training steps.
+        epoch (int): current number of training epochs.
+        r (int): model reduction rate for Tacotron models.
+        output_path (str): output path to save the model file.
+        amp_state_dict (dict, optional): Apex.amp state dict if Apex is enabled. Defaults to None.
+    """
     if hasattr(model, 'module'):
         model_state = model.module.state_dict()
     else:

@@ -45,6 +67,16 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs):


 def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):
+    """Save model checkpoint, intended for saving checkpoints at training.
+
+    Args:
+        model (TTS.tts.models.Model): model object to be saved.
+        optimizer (torch.optim.Optimizer): model optimizer used for training.
+        current_step (int): current number of training steps.
+        epoch (int): current number of training epochs.
+        r (int): model reduction rate for Tacotron models.
+        output_folder (str): output folder to save the model file.
+    """
     file_name = 'checkpoint_{}.pth.tar'.format(current_step)
     checkpoint_path = os.path.join(output_folder, file_name)
     print(" > CHECKPOINT : {}".format(checkpoint_path))

@@ -52,6 +84,23 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):


 def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, **kwargs):
+    """Save model checkpoint, intended for saving the best model after each epoch.
+    It compares the current model loss with the best loss so far and saves the
+    model if the current loss is better.
+
+    Args:
+        target_loss (float): current model loss.
+        best_loss (float): best loss so far.
+        model (TTS.tts.models.Model): model object to be saved.
+        optimizer (torch.optim.Optimizer): model optimizer used for training.
+        current_step (int): current number of training steps.
+        epoch (int): current number of training epochs.
+        r (int): model reduction rate for Tacotron models.
+        output_folder (str): output folder to save the model file.
+
+    Returns:
+        float: updated current best loss.
+    """
     if target_loss < best_loss:
         file_name = 'best_model.pth.tar'
         checkpoint_path = os.path.join(output_folder, file_name)
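A hedged sketch of how the two savers compose in a training loop; `model`, `optimizer`, `train_one_epoch` and the constants are placeholders:

    best_loss = float('inf')
    r = 1                                          # reduction rate, assumed
    for epoch in range(100):
        loss, global_step = train_one_epoch()      # hypothetical training step
        save_checkpoint(model, optimizer, global_step, epoch, r, 'runs/')
        best_loss = save_best_model(loss, best_loss, model, optimizer,
                                    global_step, epoch, r, 'runs/')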