coqui-tts/TTS/tts/layers/speedy_speech/duration_predictor.py

40 lines
997 B
Python

from torch import nn
from TTS.tts.layers.generic.res_conv_bn import ConvBN
class DurationPredictor(nn.Module):
    """Duration predictor used by Speedy Speech.

    Maps encoder outputs to one value per time step (one per phoneme).
    The values produced are log-durations; apply ``exp`` to the output to
    recover actual durations.

    Layer stack, applied in order, with the mask re-applied after each one:
        ConvBN (4x1) -> ConvBN (3x1) -> ConvBN (1x1) -> Conv1d (1x1, 1 channel)

    Args:
        hidden_channels (int): number of channels in the inner layers
            (``x`` is presumably expected to carry this many channels too).
    """

    def __init__(self, hidden_channels):
        super().__init__()
        # NOTE(review): ConvBN's positional arguments are presumably
        # (channels, kernel_size, dilation) -- confirm in
        # TTS.tts.layers.generic.res_conv_bn.
        # Layers are created in the same order (4, 3, 1, then the projection)
        # so the ModuleList indices layers.0 .. layers.3 -- and therefore the
        # state_dict / checkpoint keys -- stay the same.
        blocks = [ConvBN(hidden_channels, kernel, 1) for kernel in (4, 3, 1)]
        # Final 1x1 projection from hidden_channels down to a single channel.
        blocks.append(nn.Conv1d(hidden_channels, 1, 1))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, x_mask):
        """Predict log-durations for every time step.

        Args:
            x: encoder output.
            x_mask: sequence mask, multiplied into the output of every
                layer (assumed to be 0/1, so padded positions become zero).

        Returns:
            Log-duration tensor of shape [B, 1, T] (assuming each ConvBN
            preserves the time dimension -- TODO confirm).

        Shapes:
            x: [B, C, T]
            x_mask: [B, 1, T]
        """
        hidden = x
        for module in self.layers:
            hidden = module(hidden)
            # Re-mask after every layer so values at padded positions are
            # zeroed before the next layer sees them (and in the final output).
            hidden = hidden * x_mask
        return hidden