import torch
from torch import nn
from ..generic.normalization import LayerNorm


class MDNBlock(nn.Module):
    """Mixture Density Network block.

    https://arxiv.org/pdf/2003.01950.pdf

    Runs the input through a stack of pointwise (kernel size 1) 1D
    convolutions and splits the resulting channels into two groups that are
    returned as ``mu`` and ``log_sigma``.

    Args:
        in_channels (int): number of channels of the input tensor; also the
            width of the hidden layer.
        out_channels (int): number of channels produced by the last
            convolution. The first ``out_channels // 2`` channels are returned
            as ``mu`` and the remaining ones as ``log_sigma``, so an odd value
            gives ``log_sigma`` one more channel than ``mu``.
        dropout_p (float): dropout probability used inside the block.
            Defaults to 0.1, the value that used to be hard-coded.
    """
    def __init__(self, in_channels, out_channels, dropout_p=0.1):
        super().__init__()
        self.out_channels = out_channels
        self.mdn = nn.Sequential(
            nn.Conv1d(in_channels, in_channels, 1),
            LayerNorm(in_channels),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Conv1d(in_channels, out_channels, 1),
        )

    def forward(self, x):
        """Project ``x`` and split the result into ``mu`` and ``log_sigma``.

        Args:
            x (torch.Tensor): input fed to the convolutional stack; the output
                is indexed with three axes, so a batched
                ``[B, in_channels, T]`` tensor is expected.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``(mu, log_sigma)``, the two
            slices of the projected tensor taken along dimension 1.
        """
        mu_sigma = self.mdn(x)
        # Index on the channel axis where `mu` ends and `log_sigma` begins.
        half = self.out_channels // 2
        # TODO: check whether `mu` should be squashed with `torch.sigmoid`.
        mu = mu_sigma[:, :half, :]
        log_sigma = mu_sigma[:, half:, :]
        return mu, log_sigma