From f82f1970b8847185783dcf2a1a295260c2863616 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 1 Jun 2021 16:26:28 +0200 Subject: [PATCH] change `to(device)` to `type_as` in models --- TTS/tts/models/tacotron_abstract.py | 12 +++++------- TTS/tts/tf/utils/generic_utils.py | 3 +-- TTS/tts/utils/ssim.py | 9 +-------- TTS/vocoder/models/wavernn.py | 13 ++++++------- TTS/vocoder/utils/distribution.py | 4 +--- 5 files changed, 14 insertions(+), 27 deletions(-) diff --git a/TTS/tts/models/tacotron_abstract.py b/TTS/tts/models/tacotron_abstract.py index 5e561066..fe43d81f 100644 --- a/TTS/tts/models/tacotron_abstract.py +++ b/TTS/tts/models/tacotron_abstract.py @@ -145,14 +145,13 @@ class TacotronAbstract(ABC, nn.Module): def compute_masks(self, text_lengths, mel_lengths): """Compute masks against sequence paddings.""" # B x T_in_max (boolean) - device = text_lengths.device - input_mask = sequence_mask(text_lengths).to(device) + input_mask = sequence_mask(text_lengths) output_mask = None if mel_lengths is not None: max_len = mel_lengths.max() r = self.decoder.r max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len - output_mask = sequence_mask(mel_lengths, max_len=max_len).to(device) + output_mask = sequence_mask(mel_lengths, max_len=max_len) return input_mask, output_mask def _backward_pass(self, mel_specs, encoder_outputs, mask): @@ -195,20 +194,19 @@ class TacotronAbstract(ABC, nn.Module): def compute_gst(self, inputs, style_input, speaker_embedding=None): """Compute global style token""" - device = inputs.device if isinstance(style_input, dict): - query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).to(device) + query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs) if speaker_embedding is not None: query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1) _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens) - gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).to(device) + gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs) for k_token, v_amplifier in style_input.items(): key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1) gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key) gst_outputs = gst_outputs + gst_outputs_att * v_amplifier elif style_input is None: - gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).to(device) + gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs) else: gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable inputs = self._concat_speaker_embedding(inputs, gst_outputs) diff --git a/TTS/tts/tf/utils/generic_utils.py b/TTS/tts/tf/utils/generic_utils.py index 5b8b4ce2..e76893c2 100644 --- a/TTS/tts/tf/utils/generic_utils.py +++ b/TTS/tts/tf/utils/generic_utils.py @@ -44,8 +44,7 @@ def sequence_mask(sequence_length, max_len=None): batch_size = sequence_length.size(0) seq_range = np.empty([0, max_len], dtype=np.int8) seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) - if sequence_length.is_cuda: - seq_range_expand = seq_range_expand.cuda() + seq_range_expand = seq_range_expand.type_as(sequence_length) seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand) # B x T_max return seq_range_expand < seq_length_expand diff --git a/TTS/tts/utils/ssim.py b/TTS/tts/utils/ssim.py index 11107e47..caed575f 100644 --- a/TTS/tts/utils/ssim.py +++ b/TTS/tts/utils/ssim.py @@ -56,9 +56,6 @@ class SSIM(torch.nn.Module): window = self.window else: window = create_window(self.window_size, channel) - - if img1.is_cuda: - window = window.cuda(img1.get_device()) window = window.type_as(img1) self.window = window @@ -69,10 +66,6 @@ class SSIM(torch.nn.Module): def ssim(img1, img2, window_size=11, size_average=True): (_, channel, _, _) = img1.size() - window = create_window(window_size, channel) - - if img1.is_cuda: - window = window.cuda(img1.get_device()) + window = create_window(window_size, channel).type_as(img1) window = window.type_as(img1) - return _ssim(img1, img2, window, window_size, channel, size_average) diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 994244dc..04040931 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -251,7 +251,6 @@ class WaveRNN(nn.Module): def inference(self, mels, batched=None, target=None, overlap=None): self.eval() - device = mels.device output = [] start = time.time() rnn1 = self.get_gru_cell(self.rnn1) @@ -259,7 +258,7 @@ class WaveRNN(nn.Module): with torch.no_grad(): if isinstance(mels, np.ndarray): - mels = torch.FloatTensor(mels).to(device) + mels = torch.FloatTensor(mels).type_as(mels) if mels.ndim == 2: mels = mels.unsqueeze(0) @@ -275,9 +274,9 @@ class WaveRNN(nn.Module): b_size, seq_len, _ = mels.size() - h1 = torch.zeros(b_size, self.rnn_dims).to(device) - h2 = torch.zeros(b_size, self.rnn_dims).to(device) - x = torch.zeros(b_size, 1).to(device) + h1 = torch.zeros(b_size, self.rnn_dims).type_as(mels) + h2 = torch.zeros(b_size, self.rnn_dims).type_as(mels) + x = torch.zeros(b_size, 1).type_as(mels) if self.use_aux_net: d = self.aux_dims @@ -310,11 +309,11 @@ class WaveRNN(nn.Module): if self.mode == "mold": sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2)) output.append(sample.view(-1)) - x = sample.transpose(0, 1).to(device) + x = sample.transpose(0, 1).type_as(mels) elif self.mode == "gauss": sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2)) output.append(sample.view(-1)) - x = sample.transpose(0, 1).to(device) + x = sample.transpose(0, 1).type_as(mels) elif isinstance(self.mode, int): posterior = F.softmax(logits, dim=1) distrib = torch.distributions.Categorical(posterior) diff --git a/TTS/vocoder/utils/distribution.py b/TTS/vocoder/utils/distribution.py index 5c2742c8..43d0d884 100644 --- a/TTS/vocoder/utils/distribution.py +++ b/TTS/vocoder/utils/distribution.py @@ -149,8 +149,6 @@ def sample_from_discretized_mix_logistic(y, log_scale_min=None): def to_one_hot(tensor, n, fill_with=1.0): # we perform one hot encore with respect to the last axis - one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_() - if tensor.is_cuda: - one_hot = one_hot.cuda() + one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_().type_as(tensor) one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) return one_hot