mirror of https://github.com/coqui-ai/TTS.git
change `to(device)` to `type_as` in models
This commit is contained in:
parent 9c94b0c5c0
commit f82f1970b8
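The pattern is the same throughout: instead of looking up a device by hand and calling `.to(device)` on each freshly allocated tensor, the tensor is cast with `.type_as(ref)`, which matches both the dtype and the CPU/CUDA placement of a reference tensor already in scope. A minimal sketch of the two patterns (illustrative only, not code from this commit):

    import torch

    ref = torch.randn(2, 3)            # stand-in for e.g. the encoder inputs
    if torch.cuda.is_available():      # exercise the device-moving path
        ref = ref.cuda()

    # Before: track the device explicitly.
    device = ref.device
    h_old = torch.zeros(2, 4).to(device)

    # After: let the reference tensor supply dtype and placement.
    h_new = torch.zeros(2, 4).type_as(ref)

    assert h_old.device == h_new.device
    assert h_new.dtype == ref.dtype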
@@ -145,14 +145,13 @@ class TacotronAbstract(ABC, nn.Module):
     def compute_masks(self, text_lengths, mel_lengths):
         """Compute masks against sequence paddings."""
         # B x T_in_max (boolean)
-        device = text_lengths.device
-        input_mask = sequence_mask(text_lengths).to(device)
+        input_mask = sequence_mask(text_lengths)
         output_mask = None
         if mel_lengths is not None:
             max_len = mel_lengths.max()
             r = self.decoder.r
             max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
-            output_mask = sequence_mask(mel_lengths, max_len=max_len).to(device)
+            output_mask = sequence_mask(mel_lengths, max_len=max_len)
         return input_mask, output_mask

     def _backward_pass(self, mel_specs, encoder_outputs, mask):
@@ -195,20 +194,19 @@ class TacotronAbstract(ABC, nn.Module):

     def compute_gst(self, inputs, style_input, speaker_embedding=None):
         """Compute global style token"""
-        device = inputs.device
         if isinstance(style_input, dict):
-            query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).to(device)
+            query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs)
             if speaker_embedding is not None:
                 query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)

             _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
-            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).to(device)
+            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
             for k_token, v_amplifier in style_input.items():
                 key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
                 gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
                 gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
         elif style_input is None:
-            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).to(device)
+            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
         else:
             gst_outputs = self.gst_layer(style_input, speaker_embedding)  # pylint: disable=not-callable
         inputs = self._concat_speaker_embedding(inputs, gst_outputs)
@@ -44,8 +44,7 @@ def sequence_mask(sequence_length, max_len=None):
     batch_size = sequence_length.size(0)
     seq_range = torch.arange(0, max_len).long()
     seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
-    if sequence_length.is_cuda:
-        seq_range_expand = seq_range_expand.cuda()
+    seq_range_expand = seq_range_expand.type_as(sequence_length)
     seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)
     # B x T_max
     return seq_range_expand < seq_length_expand
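A quick sanity check of the rewritten helper (a sketch, assuming the `sequence_mask` above is importable; in recent PyTorch the comparison returns a `torch.bool` tensor):

    import torch

    lengths = torch.tensor([3, 1, 2])   # batch of sequence lengths, B = 3
    mask = sequence_mask(lengths)       # B x T_max
    # tensor([[ True,  True,  True],
    #         [ True, False, False],
    #         [ True,  True, False]])

Because the range tensor is now cast with `type_as(sequence_length)`, the mask lands on whatever device the length tensor lives on, with no `is_cuda` branch.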
@@ -56,9 +56,6 @@ class SSIM(torch.nn.Module):
             window = self.window
         else:
             window = create_window(self.window_size, channel)
-
-            if img1.is_cuda:
-                window = window.cuda(img1.get_device())
             window = window.type_as(img1)

             self.window = window
@@ -69,10 +66,6 @@ class SSIM(torch.nn.Module):

 def ssim(img1, img2, window_size=11, size_average=True):
     (_, channel, _, _) = img1.size()
-    window = create_window(window_size, channel)
-
-    if img1.is_cuda:
-        window = window.cuda(img1.get_device())
-    window = window.type_as(img1)
+    window = create_window(window_size, channel).type_as(img1)

     return _ssim(img1, img2, window, window_size, channel, size_average)
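In both SSIM paths the removed `.cuda()` branch was redundant with the `type_as` cast that already followed it: `type_as` matches dtype as well as placement, so a single cast covers both. For example, a half-precision input pulls the window to `float16` (a standalone illustration, not code from the commit):

    import torch

    img = torch.randn(1, 1, 16, 16, dtype=torch.float16)  # e.g. AMP inference
    window = torch.ones(1, 1, 11, 11)                      # float32 by default
    window = window.type_as(img)
    assert window.dtype == img.dtype and window.device == img.device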
@@ -251,7 +251,6 @@ class WaveRNN(nn.Module):
     def inference(self, mels, batched=None, target=None, overlap=None):

         self.eval()
-        device = mels.device
         output = []
         start = time.time()
         rnn1 = self.get_gru_cell(self.rnn1)
@@ -259,7 +258,7 @@ class WaveRNN(nn.Module):

         with torch.no_grad():
             if isinstance(mels, np.ndarray):
-                mels = torch.FloatTensor(mels).to(device)
+                mels = torch.FloatTensor(mels).type_as(mels)

             if mels.ndim == 2:
                 mels = mels.unsqueeze(0)
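One hunk worth a second look: when `mels` arrives as a NumPy array, the new line types the converted tensor against the array itself (`type_as(mels)`) rather than against something known to be on the model's device. A common alternative pattern (a hypothetical sketch, not what this commit does) is to type new tensors against a module parameter, which always lives on the model's device:

    import numpy as np
    import torch
    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

        def forward(self, mels):
            if isinstance(mels, np.ndarray):
                # Parameters follow the module across .to()/.cuda() calls,
                # so they are a safe dtype/device reference.
                mels = torch.FloatTensor(mels).type_as(self.proj.weight)
            return self.proj(mels)

    out = Toy()(np.zeros((2, 4), dtype=np.float32))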
@@ -275,9 +274,9 @@ class WaveRNN(nn.Module):

         b_size, seq_len, _ = mels.size()

-        h1 = torch.zeros(b_size, self.rnn_dims).to(device)
-        h2 = torch.zeros(b_size, self.rnn_dims).to(device)
-        x = torch.zeros(b_size, 1).to(device)
+        h1 = torch.zeros(b_size, self.rnn_dims).type_as(mels)
+        h2 = torch.zeros(b_size, self.rnn_dims).type_as(mels)
+        x = torch.zeros(b_size, 1).type_as(mels)

         if self.use_aux_net:
             d = self.aux_dims
@@ -310,11 +309,11 @@ class WaveRNN(nn.Module):
             if self.mode == "mold":
                 sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
                 output.append(sample.view(-1))
-                x = sample.transpose(0, 1).to(device)
+                x = sample.transpose(0, 1).type_as(mels)
             elif self.mode == "gauss":
                 sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
                 output.append(sample.view(-1))
-                x = sample.transpose(0, 1).to(device)
+                x = sample.transpose(0, 1).type_as(mels)
             elif isinstance(self.mode, int):
                 posterior = F.softmax(logits, dim=1)
                 distrib = torch.distributions.Categorical(posterior)
@@ -149,8 +149,6 @@ def sample_from_discretized_mix_logistic(y, log_scale_min=None):

 def to_one_hot(tensor, n, fill_with=1.0):
     # we perform one hot encoding with respect to the last axis
-    one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
-    if tensor.is_cuda:
-        one_hot = one_hot.cuda()
+    one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_().type_as(tensor)
     one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
     return one_hot
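A behavioral check of the rewritten helper (a sketch; one subtlety is that `type_as(tensor)` makes the one-hot output follow the index tensor's dtype, typically `int64`, whereas the old code always returned a `FloatTensor`):

    import torch

    idx = torch.tensor([0, 2, 1])  # int64 class indices
    one_hot = torch.FloatTensor(idx.size() + (3,)).zero_().type_as(idx)
    one_hot.scatter_(len(idx.size()), idx.unsqueeze(-1), 1.0)
    # tensor([[1, 0, 0],
    #         [0, 0, 1],
    #         [0, 1, 0]])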