diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 723f18dd..18476798 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -17,6 +17,7 @@ from TTS.utils.audio import AudioProcessor from TTS.utils.audio.numpy_transforms import mulaw_decode from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset from TTS.vocoder.layers.losses import WaveRNNLoss +from TTS.vocoder.layers.upsample import Stretch2d from TTS.vocoder.models.base_vocoder import BaseVocoder from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian @@ -66,19 +67,6 @@ class MelResNet(nn.Module): return x -class Stretch2d(nn.Module): - def __init__(self, x_scale, y_scale): - super().__init__() - self.x_scale = x_scale - self.y_scale = y_scale - - def forward(self, x): - b, c, h, w = x.size() - x = x.unsqueeze(-1).unsqueeze(3) - x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale) - return x.view(b, c, h * self.y_scale, w * self.x_scale) - - class UpsampleNetwork(nn.Module): def __init__( self,