from torch import nn from TTS.tts.layers.generic.res_conv_bn import ConvBN class DurationPredictor(nn.Module): """Predicts phoneme log durations based on the encoder outputs""" def __init__(self, hidden_channels): super().__init__() self.layers = nn.ModuleList([ ConvBN(hidden_channels, 4, 1), ConvBN(hidden_channels, 3, 1), ConvBN(hidden_channels, 1, 1), nn.Conv1d(hidden_channels, 1, 1) ]) def forward(self, x, x_mask): """Outputs interpreted as log(durations) To get actual durations, do exp transformation :param x: :return: """ o = x for layer in self.layers: o = layer(o) * x_mask return o