From 25551c463401018000f1ae9e5c9ba11c6845cc2c Mon Sep 17 00:00:00 2001
From: erogol
Date: Thu, 12 Nov 2020 12:52:52 +0100
Subject: [PATCH] change wavernn generate to inference

---
 TTS/vocoder/models/wavernn.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py
index 638cbebc..99053dfd 100644
--- a/TTS/vocoder/models/wavernn.py
+++ b/TTS/vocoder/models/wavernn.py
@@ -260,17 +260,18 @@ class WaveRNN(nn.Module):
         x = F.relu(self.fc2(x))
         return self.fc3(x)
 
-    def generate(self, mels, batched, target, overlap, use_cuda=False):
+    def inference(self, mels, batched, target, overlap):
         self.eval()
-        device = 'cuda' if use_cuda else 'cpu'
+        device = mels.device
         output = []
         start = time.time()
         rnn1 = self.get_gru_cell(self.rnn1)
         rnn2 = self.get_gru_cell(self.rnn2)
 
         with torch.no_grad():
-            mels = torch.FloatTensor(mels).unsqueeze(0).to(device)
+            if isinstance(mels, np.ndarray):
+                mels = torch.FloatTensor(mels).unsqueeze(0).to(device)
             #mels = torch.FloatTensor(mels).cuda().unsqueeze(0)
             wave_len = (mels.size(-1) - 1) * self.hop_length
             mels = self.pad_tensor(mels.transpose(
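
For reference, a minimal caller-side sketch of how the renamed API might be used after this patch. It is not part of the commit: it assumes an already-constructed WaveRNN instance named `model`, and the dummy mel shape and the `batched`/`target`/`overlap` values are illustrative placeholders only.

    import torch

    # With this patch the device is read from the input tensor (device = mels.device),
    # so the caller moves both the model and the mel spectrogram to the same device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Tensor inputs are expected already batched, i.e. shape [1, n_mels, T];
    # the dimensions below (80 mel bins, 200 frames) are placeholders.
    mel = torch.randn(1, 80, 200, device=device)

    # Old call site: the device was selected via an explicit flag.
    # wav = model.generate(mel_np, batched=True, target=11000, overlap=550, use_cuda=True)

    # New call site: same sliding-window arguments, no use_cuda flag.
    wav = model.inference(mel, batched=True, target=11000, overlap=550)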