diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 638cbebc..99053dfd 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -260,17 +260,18 @@ class WaveRNN(nn.Module): x = F.relu(self.fc2(x)) return self.fc3(x) - def generate(self, mels, batched, target, overlap, use_cuda=False): + def inference(self, mels, batched, target, overlap): self.eval() - device = 'cuda' if use_cuda else 'cpu' + device = mels.device output = [] start = time.time() rnn1 = self.get_gru_cell(self.rnn1) rnn2 = self.get_gru_cell(self.rnn2) with torch.no_grad(): - mels = torch.FloatTensor(mels).unsqueeze(0).to(device) + if isinstance(mels, np.ndarray): + mels = torch.FloatTensor(mels).unsqueeze(0).to(device) #mels = torch.FloatTensor(mels).cuda().unsqueeze(0) wave_len = (mels.size(-1) - 1) * self.hop_length mels = self.pad_tensor(mels.transpose(