# coqui-tts/TTS/tts/layers/speedy_speech/duration_predictor.py
#
# Source listing metadata: 28 lines, 761 B, Python.

from torch import nn
from TTS.tts.layers.generic.res_conv_bn import ConvBN
class DurationPredictor(nn.Module):
    """Predicts phoneme log durations based on the encoder outputs.

    The network is a stack of three ``ConvBN`` stages followed by a
    pointwise ``nn.Conv1d`` that projects ``hidden_channels`` down to a
    single output channel. The mask is re-applied after every stage.

    Note:
        Outputs are interpreted as log(durations); apply an exp
        transformation to obtain actual durations.

    Args:
        hidden_channels (int): channel count passed to every ``ConvBN``
            stage and used as the input channel count of the final
            ``nn.Conv1d`` projection.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        # The second ConvBN argument is presumably the kernel size and the
        # third the dilation -- TODO confirm against ConvBN. Stages are built
        # in the order 4, 3, 1 so parameter initialisation order (and the
        # state_dict keys layers.0..layers.3) match the layer list exactly.
        stages = [ConvBN(hidden_channels, width, 1) for width in (4, 3, 1)]
        # Pointwise projection: hidden_channels -> 1 output channel.
        stages.append(nn.Conv1d(hidden_channels, 1, 1))
        self.layers = nn.ModuleList(stages)

    def forward(self, x, x_mask):
        """Run the layer stack and return masked log-durations.

        Args:
            x: input fed to the first layer; presumably the encoder output
                of shape [B, hidden_channels, T] -- TODO confirm.
            x_mask: multiplied elementwise into every layer's output, so it
                must broadcast against it; presumably [B, 1, T].

        Returns:
            Output of the final ``nn.Conv1d`` (one channel) multiplied by
            ``x_mask``; interpreted as log(durations).
        """
        out = x
        for stage in self.layers:
            # Re-mask after each stage so positions where x_mask is 0 are
            # zeroed before they reach the next layer.
            out = stage(out)
            out = out * x_mask
        return out