diff --git a/TTS/tts/layers/glow_tts/duration_predictor.py b/TTS/tts/layers/glow_tts/duration_predictor.py
index a83bb292..a08f64a8 100644
--- a/TTS/tts/layers/glow_tts/duration_predictor.py
+++ b/TTS/tts/layers/glow_tts/duration_predictor.py
@@ -5,6 +5,15 @@ from ..generic.normalization import LayerNorm
 
 
 class DurationPredictor(nn.Module):
+    """Glow-TTS duration prediction model.
+    [2 x (conv1d_kxk -> relu -> layer_norm -> dropout)] -> conv1d_1x1 -> durs
+
+    Args:
+        in_channels (int): number of channels of the input tensor.
+        hidden_channels (int): number of hidden channels of the inner conv layers.
+        kernel_size (int): kernel size of the conv layers.
+        dropout_p (float): dropout rate applied after each conv block.
+    """
     def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p):
         super().__init__()
         # class arguments
@@ -28,6 +37,14 @@ class DurationPredictor(nn.Module):
         self.proj = nn.Conv1d(hidden_channels, 1, 1)
 
     def forward(self, x, x_mask):
+        """
+        Shapes:
+            x: [B, C, T]
+            x_mask: [B, 1, T]
+
+        Returns:
+            Tensor: predicted durations with shape [B, 1, T].
+        """
         x = self.conv_1(x * x_mask)
         x = torch.relu(x)
         x = self.norm_1(x)
diff --git a/TTS/tts/layers/glow_tts/encoder.py b/TTS/tts/layers/glow_tts/encoder.py
index 6a5c2fad..ab7aaba5 100644
--- a/TTS/tts/layers/glow_tts/encoder.py
+++ b/TTS/tts/layers/glow_tts/encoder.py
@@ -142,6 +142,12 @@ class Encoder(nn.Module):
                                                 dropout_p_dp)
 
     def forward(self, x, x_lengths, g=None):
+        """
+        Shapes:
+            x: [B, C, T]
+            x_lengths: [B]
+            g (optional): [B, 1, T]
+        """
         # embedding layer
         # [B ,T, D]
         x = self.emb(x) * math.sqrt(self.hidden_channels)
diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py
index 291c845a..826913c1 100644
--- a/TTS/tts/layers/glow_tts/transformer.py
+++ b/TTS/tts/layers/glow_tts/transformer.py
@@ -7,8 +7,46 @@ from TTS.tts.layers.glow_tts.glow import LayerNorm
 
 
 class RelativePositionMultiHeadAttention(nn.Module):
-    """Implementation of Relative Position Encoding based on
+    """Multi-head attention with Relative Positional embedding.
     https://arxiv.org/pdf/1809.04281.pdf
+
+    It learns positional embeddings for a window of neighbours. For keys and values,
+    it learns a different set of embeddings. Key embeddings are aggregated with the attention
+    scores and value embeddings are aggregated with the output.
+
+    Note:
+        Example with relative attention window size 2:
+        input = [a, b, c, d, e]
+        rel_attn_embeddings = [e(t-2), e(t-1), e(t+1), e(t+2)]
+
+        So it learns 4 embedding vectors (in total 8) separately for key and value vectors.
+
+        Considering the input c:
+        e(t-2) corresponds to c -> a
+        e(t-1) corresponds to c -> b
+        e(t+1) corresponds to c -> d
+        e(t+2) corresponds to c -> e
+
+        These embeddings are shared among different time steps, so inputs a, b, d and e also use
+        the same embeddings.
+
+        Embeddings are ignored when the relative window goes out of range for the first and the last
+        n items.
+
+    Args:
+        channels (int): input and inner layer channels.
+        out_channels (int): output channels.
+        num_heads (int): number of attention heads.
+        rel_attn_window_size (int, optional): relative attention window size.
+            If 4, for each time step the next and previous 4 time steps are attended.
+            If None (default), relative encoding is disabled and it is a regular transformer.
+            Defaults to None.
+        heads_share (bool, optional): share the relative positional embeddings across attention heads. Defaults to True.
+        dropout_p (float, optional): dropout rate. Defaults to 0.
+        input_length (int, optional): input length for positional encoding. Defaults to None.
+        proximal_bias (bool, optional): enable/disable proximal bias as in the paper. Defaults to False.
+        proximal_init (bool, optional): enable/disable proximal init as in the paper.
+            Initializes the key and query layer weights to the same values. Defaults to False.
     """
     def __init__(self,
                  channels,
@@ -20,6 +58,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
                  input_length=None,
                  proximal_bias=False,
                  proximal_init=False):
+
         super().__init__()
         assert channels % num_heads == 0, " [!] channels should be divisible by num_heads."
         # class attributes
@@ -226,20 +265,28 @@ class RelativePositionMultiHeadAttention(nn.Module):
 
 
 class FFN(nn.Module):
+    """Feed-forward inner layers for the Transformer.
+
+    Args:
+        in_channels (int): input tensor channels.
+        out_channels (int): output tensor channels.
+        hidden_channels (int): inner layers hidden channels.
+        kernel_size (int): conv1d filter kernel size.
+        dropout_p (float, optional): dropout rate. Defaults to 0.
+    """
     def __init__(self,
                  in_channels,
                  out_channels,
                  hidden_channels,
                  kernel_size,
-                 dropout_p=0.,
-                 activation=None):
+                 dropout_p=0.):
+
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.hidden_channels = hidden_channels
         self.kernel_size = kernel_size
         self.dropout_p = dropout_p
-        self.activation = activation
 
         self.conv_1 = nn.Conv1d(in_channels,
                                 hidden_channels,
@@ -253,16 +300,29 @@ class FFN(nn.Module):
 
     def forward(self, x, x_mask):
         x = self.conv_1(x * x_mask)
-        if self.activation == "gelu":
-            x = x * torch.sigmoid(1.702 * x)
-        else:
-            x = torch.relu(x)
+        x = torch.relu(x)
         x = self.dropout(x)
         x = self.conv_2(x * x_mask)
         return x * x_mask
 
 
 class RelativePositionTransformer(nn.Module):
+    """Transformer with Relative Positional Encoding.
+    https://arxiv.org/abs/1803.02155
+
+    Args:
+        hidden_channels (int): model hidden channels.
+        hidden_channels_ffn (int): hidden channels of the feed-forward (FFN) layers.
+        num_heads (int): number of attention heads.
+        num_layers (int): number of transformer layers.
+        kernel_size (int, optional): kernel size of the feed-forward inner layers. Defaults to 1.
+        dropout_p (float, optional): dropout rate for the self-attention and feed-forward inner layers. Defaults to 0.
+        rel_attn_window_size (int, optional): relative attention window size.
+            If 4, for each time step the next and previous 4 time steps are attended.
+            If None (default), relative encoding is disabled and it is a regular transformer.
+            Defaults to None.
+        input_length (int, optional): input length to limit position encoding. Defaults to None.
+    """
     def __init__(self,
                  hidden_channels,
                  hidden_channels_ffn,
@@ -305,6 +365,11 @@ class RelativePositionTransformer(nn.Module):
             self.norm_layers_2.append(LayerNorm(hidden_channels))
 
     def forward(self, x, x_mask):
+        """
+        Shapes:
+            x: [B, C, T]
+            x_mask: [B, 1, T]
+        """
         attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
         for i in range(self.num_layers):
             x = x * x_mask
diff --git a/TTS/tts/utils/io.py b/TTS/tts/utils/io.py
index f84445d9..830529a3 100644
--- a/TTS/tts/utils/io.py
+++ b/TTS/tts/utils/io.py
@@ -8,6 +8,17 @@ from TTS.utils.io import RenamingUnpickler
 
 
 def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
+    """Load ```TTS.tts.models``` checkpoints.
+
+    Args:
+        model (TTS.tts.models): model object to load the weights for.
+        checkpoint_path (str): checkpoint file path.
+        amp (apex.amp, optional): Apex amp object used to load Apex-related state vars. Defaults to None.
+        use_cuda (bool, optional): load the model to GPU if True. Defaults to False.
+
+    Returns:
+        [type]: [description]
+    """
     try:
         state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
     except ModuleNotFoundError:
@@ -26,6 +37,17 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False):
 
 
 def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_dict=None, **kwargs):
+    """Save ```TTS.tts.models``` states with extra fields.
+
+    Args:
+        model (TTS.tts.models.Model): model object to be saved.
+        optimizer (torch.optim.Optimizer): model optimizer used for training.
+        current_step (int): current number of training steps.
+        epoch (int): current number of training epochs.
+        r (int): model reduction rate for Tacotron models.
+        output_path (str): output path to save the model file.
+        amp_state_dict (state_dict, optional): Apex.amp state dict if Apex is enabled. Defaults to None.
+    """
     if hasattr(model, 'module'):
         model_state = model.module.state_dict()
     else:
@@ -45,6 +67,16 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, amp_state_
 
 
 def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):
+    """Save a model checkpoint, intended for saving checkpoints during training.
+
+    Args:
+        model (TTS.tts.models.Model): model object to be saved.
+        optimizer (torch.optim.Optimizer): model optimizer used for training.
+        current_step (int): current number of training steps.
+        epoch (int): current number of training epochs.
+        r (int): model reduction rate for Tacotron models.
+        output_folder (str): folder path to save the checkpoint file.
+    """
     file_name = 'checkpoint_{}.pth.tar'.format(current_step)
     checkpoint_path = os.path.join(output_folder, file_name)
     print(" > CHECKPOINT : {}".format(checkpoint_path))
@@ -52,6 +84,23 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **k
 
 
 def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, **kwargs):
+    """Save a model checkpoint, intended for saving the best model after each epoch.
+    It compares the current model loss with the best loss so far and saves the
+    model if the current loss is better.
+
+    Args:
+        target_loss (float): current model loss.
+        best_loss (float): best loss so far.
+        model (TTS.tts.models.Model): model object to be saved.
+        optimizer (torch.optim.Optimizer): model optimizer used for training.
+        current_step (int): current number of training steps.
+        epoch (int): current number of training epochs.
+        r (int): model reduction rate for Tacotron models.
+        output_folder (str): folder path to save the model file.
+
+    Returns:
+        float: the updated best loss.
+    """
     if target_loss < best_loss:
         file_name = 'best_model.pth.tar'
         checkpoint_path = os.path.join(output_folder, file_name)
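
As a quick sanity check for the shape convention documented in these docstrings (x: [B, C, T], x_mask: [B, 1, T]), here is a minimal sketch exercising `DurationPredictor` with random tensors. It assumes the TTS package at this revision is importable; the batch/channel/time sizes and the hidden_channels/kernel_size/dropout_p values are arbitrary illustrative choices, not values taken from the repo.

```python
import torch

from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor

# Arbitrary sizes for illustration only: batch, input channels, time steps.
B, C, T = 2, 192, 50

# Constructor signature as documented above:
# (in_channels, hidden_channels, kernel_size, dropout_p).
predictor = DurationPredictor(in_channels=C,
                              hidden_channels=256,
                              kernel_size=3,
                              dropout_p=0.1)

x = torch.randn(B, C, T)      # [B, C, T] input features
x_mask = torch.ones(B, 1, T)  # [B, 1, T] binary padding mask

durations = predictor(x, x_mask)
print(durations.shape)        # expected: torch.Size([2, 1, 50]), one value per input step
```

The same [B, C, T] / [B, 1, T] convention applies to `Encoder.forward` and `RelativePositionTransformer.forward` documented in the other hunks.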