From 4ef083f0f11fa08c41d727c93d46118b3759ac9a Mon Sep 17 00:00:00 2001
From: erogol
Date: Wed, 30 Dec 2020 14:18:31 +0100
Subject: [PATCH] select decoder type for SS

---
 TTS/tts/layers/glow_tts/glow.py         |  2 +-
 TTS/tts/layers/speedy_speech/decoder.py | 56 +++++++++++++++++++++----
 2 files changed, 49 insertions(+), 9 deletions(-)

diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py
index 7b394e43..eba593dc 100644
--- a/TTS/tts/layers/glow_tts/glow.py
+++ b/TTS/tts/layers/glow_tts/glow.py
@@ -84,7 +84,7 @@ class WN(torch.nn.Module):
         self.res_skip_layers = torch.nn.ModuleList()
         self.dropout = nn.Dropout(dropout_p)
 
-        if c_in_channels != 0:
+        if c_in_channels > 0:
             cond_layer = torch.nn.Conv1d(c_in_channels,
                                          2 * hidden_channels * num_layers, 1)
             self.cond_layer = torch.nn.utils.weight_norm(cond_layer,
diff --git a/TTS/tts/layers/speedy_speech/decoder.py b/TTS/tts/layers/speedy_speech/decoder.py
index 81c2d86d..0b928de7 100644
--- a/TTS/tts/layers/speedy_speech/decoder.py
+++ b/TTS/tts/layers/speedy_speech/decoder.py
@@ -1,32 +1,72 @@
 from torch import nn
 
 from TTS.tts.layers.generic.res_conv_bn import ConvBNBlock, ResidualConvBNBlock
+from TTS.tts.layers.glow_tts.transformer import Transformer
 
 class Decoder(nn.Module):
     """Decodes the expanded phoneme encoding into spectrograms
+    Args:
+        out_channels (int): number of output channels.
+        in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the intermediate layers.
+        decoder_type (str): decoder layer types. 'transformers' or 'residual_conv_bn'. Default 'residual_conv_bn'.
+        decoder_params (dict): model parameters for specified decoder type.
+        c_in_channels (int): number of channels for conditional input.
+
     Shapes:
         - input: (B, C, T)
+
+    Note:
+        Default decoder_params...
+
+        for 'transformer'
+            encoder_params={
+                'hidden_channels_ffn': 128,
+                'num_heads': 2,
+                "kernel_size": 3,
+                "dropout_p": 0.1,
+                "num_layers": 8,
+                "rel_attn_window_size": 4,
+                "input_length": None
+            },
+
+        for 'residual_conv_bn'
+            encoder_params = {
+                "kernel_size": 4,
+                "dilations": 4 * [1, 2, 4, 8] + [1],
+                "num_conv_blocks": 2,
+                "num_res_blocks": 17
+            }
     """
     # pylint: disable=dangerous-default-value
     def __init__(
             self,
             out_channels,
-            hidden_channels,
-            residual_conv_bn_params={
+            in_hidden_channels,
+            decoder_type='residual_conv_bn',
+            decoder_params={
                 "kernel_size": 4,
                 "dilations": 4 * [1, 2, 4, 8] + [1],
                 "num_conv_blocks": 2,
                 "num_res_blocks": 17
-            }):
+            },
+            c_in_channels=0):
         super().__init__()
+        self.in_channels = in_hidden_channels
+        self.hidden_channels = in_hidden_channels
+        self.out_channels = out_channels
 
-        self.decoder = ResidualConvBNBlock(hidden_channels,
-                                           **residual_conv_bn_params)
+        if decoder_type == 'transformer':
+            self.decoder = Transformer(self.hidden_channels, **decoder_params)
+        elif decoder_type == 'residual_conv_bn':
+            self.decoder = ResidualConvBNBlock(self.hidden_channels,
+                                               **decoder_params)
+        else:
+            raise ValueError(f'[!] Unknown decoder type - {decoder_type}')
 
-        self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
+        self.post_conv = nn.Conv1d(self.hidden_channels, self.hidden_channels, 1)
         self.post_net = nn.Sequential(
-            ConvBNBlock(hidden_channels, residual_conv_bn_params['kernel_size'], 1, num_conv_blocks=2),
-            nn.Conv1d(hidden_channels, out_channels, 1),
+            ConvBNBlock(self.hidden_channels, decoder_params['kernel_size'], 1, num_conv_blocks=2),
+            nn.Conv1d(self.hidden_channels, out_channels, 1),
         )
 
     def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
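
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the updated Decoder could be constructed with either decoder_type, using the module path from the diff and the default parameter dictionaries listed in the new docstring. The out_channels, in_hidden_channels, batch/length sizes, and the all-ones mask are assumed example values; the forward call relies only on the forward(x, x_mask, g=None) signature visible at the end of the diff.

import torch

from TTS.tts.layers.speedy_speech.decoder import Decoder

# Residual conv-BN decoder (the default); decoder_params mirrors the
# docstring defaults. 80 and 128 are assumed example channel sizes.
decoder = Decoder(out_channels=80,
                  in_hidden_channels=128,
                  decoder_type='residual_conv_bn',
                  decoder_params={
                      "kernel_size": 4,
                      "dilations": 4 * [1, 2, 4, 8] + [1],
                      "num_conv_blocks": 2,
                      "num_res_blocks": 17
                  })

# Transformer decoder, using the parameter set listed in the docstring Note.
decoder_tf = Decoder(out_channels=80,
                     in_hidden_channels=128,
                     decoder_type='transformer',
                     decoder_params={
                         'hidden_channels_ffn': 128,
                         'num_heads': 2,
                         "kernel_size": 3,
                         "dropout_p": 0.1,
                         "num_layers": 8,
                         "rel_attn_window_size": 4,
                         "input_length": None
                     })

# (B, C, T) expanded phoneme encoding plus an assumed (B, 1, T) binary mask,
# passed through the forward(x, x_mask, g=None) signature shown in the diff.
x = torch.rand(2, 128, 57)
x_mask = torch.ones(2, 1, 57)
o = decoder(x, x_mask)   # expected shape (B, out_channels, T)

Note that __init__ compares against the literal string 'transformer' (singular; the docstring's 'transformers' appears to be a typo) and raises ValueError for any other decoder_type.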