change `to(device)` to `type_as` in models

Eren Gölge 2021-06-01 16:26:28 +02:00
parent 9c94b0c5c0
commit f82f1970b8
5 changed files with 14 additions and 27 deletions
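The change is mechanical across all five files: instead of reading off a device and moving freshly created tensors onto it with `.to(device)`, each new tensor is matched to an existing reference tensor with `Tensor.type_as`, which is equivalent to `self.type(other.type())` and so picks up the reference's dtype and CUDA/CPU backend in one call. A minimal sketch of the two styles (illustrative example, not taken from the diff):

    import torch

    ref = torch.zeros(4, 80)                # reference tensor; may be fp16 and/or on GPU

    # old style: track the device by hand
    device = ref.device
    mask = torch.zeros(4, 80).to(device)    # matches device, but keeps default fp32

    # new style: inherit dtype and backend from the reference in one call
    mask = torch.zeros(4, 80).type_as(ref)  # equivalent to torch.zeros(...).type(ref.type())

One practical effect: under mixed precision, where the reference may be torch.float16, type_as also casts the new tensor, while .to(device) alone would leave it float32.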

@@ -145,14 +145,13 @@ class TacotronAbstract(ABC, nn.Module):
     def compute_masks(self, text_lengths, mel_lengths):
         """Compute masks against sequence paddings."""
         # B x T_in_max (boolean)
-        device = text_lengths.device
-        input_mask = sequence_mask(text_lengths).to(device)
+        input_mask = sequence_mask(text_lengths)
         output_mask = None
         if mel_lengths is not None:
             max_len = mel_lengths.max()
             r = self.decoder.r
             max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
-            output_mask = sequence_mask(mel_lengths, max_len=max_len).to(device)
+            output_mask = sequence_mask(mel_lengths, max_len=max_len)
         return input_mask, output_mask

     def _backward_pass(self, mel_specs, encoder_outputs, mask):
@@ -195,20 +194,19 @@ class TacotronAbstract(ABC, nn.Module):
     def compute_gst(self, inputs, style_input, speaker_embedding=None):
         """Compute global style token"""
-        device = inputs.device
         if isinstance(style_input, dict):
-            query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).to(device)
+            query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs)
             if speaker_embedding is not None:
                 query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)

             _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
-            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).to(device)
+            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
             for k_token, v_amplifier in style_input.items():
                 key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
                 gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
                 gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
         elif style_input is None:
-            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).to(device)
+            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
         else:
             gst_outputs = self.gst_layer(style_input, speaker_embedding)  # pylint: disable=not-callable

         inputs = self._concat_speaker_embedding(inputs, gst_outputs)

@@ -44,8 +44,7 @@ def sequence_mask(sequence_length, max_len=None):
     batch_size = sequence_length.size(0)
     seq_range = np.empty([0, max_len], dtype=np.int8)
     seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
-    if sequence_length.is_cuda:
-        seq_range_expand = seq_range_expand.cuda()
+    seq_range_expand = seq_range_expand.type_as(sequence_length)
     seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)
     # B x T_max
     return seq_range_expand < seq_length_expand
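The function above builds the mask by comparing a row of positions against each sequence length. A self-contained sketch of the same broadcast-comparison pattern (hypothetical helper name; it uses torch.arange for the position row, since the file's np.empty array would not support unsqueeze/expand):

    import torch

    def length_mask(lengths, max_len=None):
        """True where a position lies inside the sequence (B x T_max)."""
        if max_len is None:
            max_len = int(lengths.max())
        positions = torch.arange(max_len).type_as(lengths)   # 0 .. max_len-1, on lengths' device
        positions = positions.unsqueeze(0).expand(lengths.size(0), max_len)
        return positions < lengths.unsqueeze(1)

    print(length_mask(torch.tensor([3, 1])))
    # tensor([[ True,  True,  True],
    #         [ True, False, False]])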

@@ -56,9 +56,6 @@ class SSIM(torch.nn.Module):
             window = self.window
         else:
             window = create_window(self.window_size, channel)
-            if img1.is_cuda:
-                window = window.cuda(img1.get_device())
             window = window.type_as(img1)

             self.window = window
@@ -69,10 +66,6 @@
 def ssim(img1, img2, window_size=11, size_average=True):
     (_, channel, _, _) = img1.size()
-    window = create_window(window_size, channel)
-    if img1.is_cuda:
-        window = window.cuda(img1.get_device())
-    window = window.type_as(img1)
+    window = create_window(window_size, channel).type_as(img1)
     return _ssim(img1, img2, window, window_size, channel, size_average)

@@ -251,7 +251,6 @@ class WaveRNN(nn.Module):
     def inference(self, mels, batched=None, target=None, overlap=None):
         self.eval()
-        device = mels.device
         output = []
         start = time.time()
         rnn1 = self.get_gru_cell(self.rnn1)
@@ -259,7 +258,7 @@ class WaveRNN(nn.Module):
         with torch.no_grad():
             if isinstance(mels, np.ndarray):
-                mels = torch.FloatTensor(mels).to(device)
+                mels = torch.FloatTensor(mels).type_as(mels)

             if mels.ndim == 2:
                 mels = mels.unsqueeze(0)
@@ -275,9 +274,9 @@ class WaveRNN(nn.Module):
             b_size, seq_len, _ = mels.size()
-            h1 = torch.zeros(b_size, self.rnn_dims).to(device)
-            h2 = torch.zeros(b_size, self.rnn_dims).to(device)
-            x = torch.zeros(b_size, 1).to(device)
+            h1 = torch.zeros(b_size, self.rnn_dims).type_as(mels)
+            h2 = torch.zeros(b_size, self.rnn_dims).type_as(mels)
+            x = torch.zeros(b_size, 1).type_as(mels)

             if self.use_aux_net:
                 d = self.aux_dims
@@ -310,11 +309,11 @@ class WaveRNN(nn.Module):
                 if self.mode == "mold":
                     sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
                     output.append(sample.view(-1))
-                    x = sample.transpose(0, 1).to(device)
+                    x = sample.transpose(0, 1).type_as(mels)
                 elif self.mode == "gauss":
                     sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
                     output.append(sample.view(-1))
-                    x = sample.transpose(0, 1).to(device)
+                    x = sample.transpose(0, 1).type_as(mels)
                 elif isinstance(self.mode, int):
                     posterior = F.softmax(logits, dim=1)
                     distrib = torch.distributions.Categorical(posterior)

@@ -149,8 +149,6 @@ def sample_from_discretized_mix_logistic(y, log_scale_min=None):
 def to_one_hot(tensor, n, fill_with=1.0):
     # we perform one hot encoding with respect to the last axis
-    one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
-    if tensor.is_cuda:
-        one_hot = one_hot.cuda()
+    one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_().type_as(tensor)
     one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
     return one_hot
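The scatter_ call is what writes the hot entries: along the new last axis it places fill_with at the index stored in tensor. A standalone sketch of the same trick (assuming a 1-D long index tensor, so the scatter dimension is 1):

    import torch

    idx = torch.tensor([2, 0, 1])                  # one class index per element
    one_hot = torch.zeros(idx.size() + (4,))       # 3 x 4, all zeros
    one_hot.scatter_(1, idx.unsqueeze(-1), 1.0)    # write 1.0 at each index along dim 1
    print(one_hot)
    # tensor([[0., 0., 1., 0.],
    #         [1., 0., 0., 0.],
    #         [0., 1., 0., 0.]])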