from torch import nn

from TTS.tts.layers.generic.res_conv_bn import ConvBNBlock, ResidualConvBNBlock
from TTS.tts.layers.generic.wavenet import WNBlocks
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer


class Decoder(nn.Module):
    """Decodes the expanded phoneme encoding into spectrograms

    Args:
        out_channels (int): number of output channels.
        in_hidden_channels (int): input and hidden channels. Model keeps the input channels for the
            intermediate layers.
        decoder_type (str): decoder layer type. One of 'transformer', 'residual_conv_bn' or 'wavenet'.
            Default 'residual_conv_bn'.
        decoder_params (dict, optional): model parameters for the specified decoder type. They are
            forwarded as keyword arguments to the decoder layer. ``decoder_params['kernel_size']`` is
            also used by the post-net. When ``None`` (default), the 'residual_conv_bn' defaults
            listed below are used.
        c_in_channels (int): number of channels for conditional input. Currently unused.

    Raises:
        ValueError: if ``decoder_type`` is not one of the supported types.

    Shapes:
        - input: (B, C, T)
        - output: (B, out_channels, T) -- assumes the post-net preserves T; confirm against ConvBNBlock.

    Note:
        Default decoder_params...

        for 'transformer'
            decoder_params={
                'hidden_channels_ffn': 128,
                'num_heads': 2,
                "kernel_size": 3,
                "dropout_p": 0.1,
                "num_layers": 8,
                "rel_attn_window_size": 4,
                "input_length": None
            },

        for 'residual_conv_bn'
            decoder_params = {
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4, 8] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 17
            }

        for 'wavenet'
            decoder_params = {
                "num_blocks": 12,
                "hidden_channels":192,
                "kernel_size": 5,
                "dilation_rate": 1,
                "num_layers": 4,
                "dropout_p": 0.05
            }
    """

    def __init__(
            self,
            out_channels,
            in_hidden_channels,
            decoder_type='residual_conv_bn',
            decoder_params=None,
            c_in_channels=0):
        super().__init__()
        if decoder_params is None:
            # Build a fresh dict per instance instead of sharing one mutable default across instances.
            decoder_params = {
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4, 8] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 17
            }
        self.in_channels = in_hidden_channels
        self.hidden_channels = in_hidden_channels
        self.out_channels = out_channels

        if decoder_type == 'transformer':
            self.decoder = RelativePositionTransformer(self.hidden_channels, **decoder_params)
        elif decoder_type == 'residual_conv_bn':
            self.decoder = ResidualConvBNBlock(self.hidden_channels, **decoder_params)
        elif decoder_type == 'wavenet':
            self.decoder = WNBlocks(in_channels=self.in_channels,
                                    hidden_channels=self.hidden_channels,
                                    **decoder_params)
        else:
            raise ValueError(f'[!] Unknown decoder type - {decoder_type}')

        # 1x1 projection applied to the decoder output before the residual connection in forward().
        self.post_conv = nn.Conv1d(self.hidden_channels, self.hidden_channels, 1)
        # Post-net reuses the decoder's kernel_size, so decoder_params must define it.
        self.post_net = nn.Sequential(
            ConvBNBlock(self.hidden_channels, decoder_params['kernel_size'], 1, num_conv_blocks=2),
            nn.Conv1d(self.hidden_channels, out_channels, 1),
        )

    def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
        """Run the decoder, add a residual connection to the input, and apply the post-net.

        Args:
            x: expanded encoder output, (B, C, T).
            x_mask: mask multiplied into the output; presumably (B, 1, T) -- TODO confirm.
            g: conditioning input; currently unused.

        Returns:
            Post-net output multiplied by ``x_mask``.
        """
        # TODO: implement multi-speaker
        o = self.decoder(x, x_mask)
        # Residual connection: requires the decoder output to have the same shape as x.
        o = self.post_conv(o) + x
        return self.post_net(o) * x_mask