mirror of https://github.com/coqui-ai/TTS.git

commit 26a3312f0d
parent c09622459e

change `to(device)` to `type_as` in models
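
Every hunk below applies the same pattern: instead of capturing a device up front and calling `.to(device)` (or branching on `.is_cuda`), freshly created tensors now use `.type_as(reference)`, which casts them to the reference tensor's dtype and places them on its device in one call. A minimal sketch of the pattern (the tensor names and shapes are illustrative, not from the diff):

    import torch

    ref = torch.randn(4, 80)                 # reference tensor; may be CPU/GPU, float32/float16

    # before: device captured and applied explicitly
    device = ref.device
    h_old = torch.zeros(4, 256).to(device)

    # after: dtype AND device both follow the reference tensor
    h_new = torch.zeros(4, 256).type_as(ref)

    assert h_new.device == ref.device and h_new.dtype == ref.dtype

Note the difference: `.to(device)` only moves the tensor, while `.type_as` also casts it, which matters under mixed precision.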

@@ -145,14 +145,13 @@ class TacotronAbstract(ABC, nn.Module):
     def compute_masks(self, text_lengths, mel_lengths):
         """Compute masks against sequence paddings."""
         # B x T_in_max (boolean)
-        device = text_lengths.device
-        input_mask = sequence_mask(text_lengths).to(device)
+        input_mask = sequence_mask(text_lengths)
         output_mask = None
         if mel_lengths is not None:
             max_len = mel_lengths.max()
             r = self.decoder.r
             max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
-            output_mask = sequence_mask(mel_lengths, max_len=max_len).to(device)
+            output_mask = sequence_mask(mel_lengths, max_len=max_len)
         return input_mask, output_mask
 
     def _backward_pass(self, mel_specs, encoder_outputs, mask):

@@ -195,20 +194,19 @@ class TacotronAbstract(ABC, nn.Module):
 
     def compute_gst(self, inputs, style_input, speaker_embedding=None):
         """Compute global style token"""
-        device = inputs.device
         if isinstance(style_input, dict):
-            query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).to(device)
+            query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs)
             if speaker_embedding is not None:
                 query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
 
             _GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
-            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).to(device)
+            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
             for k_token, v_amplifier in style_input.items():
                 key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
                 gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
                 gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
         elif style_input is None:
-            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).to(device)
+            gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
         else:
             gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable
         inputs = self._concat_speaker_embedding(inputs, gst_outputs)
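
For context, when `style_input` is a dict, `compute_gst` maps token indices to amplifiers and mixes the selected style tokens. The real code scores each token against the query with the GST attention module; the toy sketch below (made-up sizes, attention omitted) shows only the dict-driven mixing:

    import torch

    style_tokens = torch.randn(10, 256)    # stand-in for the learned GST tokens
    style_input = {"2": 0.3, "7": -0.1}    # token index -> amplifier, as in compute_gst

    gst_outputs = torch.zeros(1, 1, 256)
    for k_token, v_amplifier in style_input.items():
        token = torch.tanh(style_tokens[int(k_token)]).view(1, 1, -1)
        gst_outputs = gst_outputs + token * v_amplifier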

@@ -44,8 +44,7 @@ def sequence_mask(sequence_length, max_len=None):
     batch_size = sequence_length.size(0)
     seq_range = torch.arange(0, max_len).long()
     seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
-    if sequence_length.is_cuda:
-        seq_range_expand = seq_range_expand.cuda()
+    seq_range_expand = seq_range_expand.type_as(sequence_length)
     seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)
     # B x T_max
     return seq_range_expand < seq_length_expand
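
`sequence_mask` compares a position ramp against the lengths to get a boolean padding mask; after this change the ramp follows the lengths tensor via `type_as` rather than an `is_cuda` branch. A small worked example:

    import torch

    lengths = torch.tensor([3, 1, 2])
    max_len = int(lengths.max())
    seq_range = torch.arange(0, max_len).long().type_as(lengths)
    mask = seq_range.unsqueeze(0).expand(len(lengths), max_len) < lengths.unsqueeze(1)
    # tensor([[ True,  True,  True],
    #         [ True, False, False],
    #         [ True,  True, False]])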

@@ -56,9 +56,6 @@ class SSIM(torch.nn.Module):
             window = self.window
         else:
             window = create_window(self.window_size, channel)
-
-            if img1.is_cuda:
-                window = window.cuda(img1.get_device())
             window = window.type_as(img1)
 
             self.window = window

@@ -69,10 +66,6 @@ class SSIM(torch.nn.Module):
 
 def ssim(img1, img2, window_size=11, size_average=True):
     (_, channel, _, _) = img1.size()
-    window = create_window(window_size, channel)
-
-    if img1.is_cuda:
-        window = window.cuda(img1.get_device())
-    window = window.type_as(img1)
+    window = create_window(window_size, channel).type_as(img1)
 
     return _ssim(img1, img2, window, window_size, channel, size_average)
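
In both SSIM call sites, the single `window.type_as(img1)` (or the chained `create_window(...).type_as(img1)`) covers both halves of the removed branch: it moves the Gaussian window to `img1`'s device and casts it to `img1`'s dtype, making the explicit `.cuda(img1.get_device())` call redundant. A quick check of that equivalence with a stand-in window:

    import torch

    img1 = torch.randn(1, 3, 32, 32)    # on GPU this could be .cuda() or .half()
    window = torch.randn(3, 1, 11, 11)  # stand-in for create_window(window_size, channel)

    window = window.type_as(img1)       # one call: right device and right dtype
    assert window.device == img1.device and window.dtype == img1.dtype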

@@ -251,7 +251,6 @@ class WaveRNN(nn.Module):
     def inference(self, mels, batched=None, target=None, overlap=None):
 
         self.eval()
-        device = mels.device
         output = []
         start = time.time()
         rnn1 = self.get_gru_cell(self.rnn1)

@@ -259,7 +258,7 @@ class WaveRNN(nn.Module):
 
         with torch.no_grad():
             if isinstance(mels, np.ndarray):
-                mels = torch.FloatTensor(mels).to(device)
+                mels = torch.FloatTensor(mels).type_as(mels)
 
             if mels.ndim == 2:
                 mels = mels.unsqueeze(0)
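
For the NumPy branch in general, `type_as` needs a tensor that already lives on the target device to serve as the reference, since the NumPy array itself carries none; a minimal sketch with a hypothetical reference tensor `ref`:

    import numpy as np
    import torch

    arr = np.random.randn(80, 100).astype(np.float32)
    ref = torch.zeros(1)                       # hypothetical: any tensor on the target device
    mels = torch.from_numpy(arr).type_as(ref)  # matches ref's dtype and device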

@@ -275,9 +274,9 @@ class WaveRNN(nn.Module):
 
             b_size, seq_len, _ = mels.size()
 
-            h1 = torch.zeros(b_size, self.rnn_dims).to(device)
-            h2 = torch.zeros(b_size, self.rnn_dims).to(device)
-            x = torch.zeros(b_size, 1).to(device)
+            h1 = torch.zeros(b_size, self.rnn_dims).type_as(mels)
+            h2 = torch.zeros(b_size, self.rnn_dims).type_as(mels)
+            x = torch.zeros(b_size, 1).type_as(mels)
 
             if self.use_aux_net:
                 d = self.aux_dims

@@ -310,11 +309,11 @@ class WaveRNN(nn.Module):
                 if self.mode == "mold":
                     sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
                     output.append(sample.view(-1))
-                    x = sample.transpose(0, 1).to(device)
+                    x = sample.transpose(0, 1).type_as(mels)
                 elif self.mode == "gauss":
                     sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2))
                     output.append(sample.view(-1))
-                    x = sample.transpose(0, 1).to(device)
+                    x = sample.transpose(0, 1).type_as(mels)
                 elif isinstance(self.mode, int):
                     posterior = F.softmax(logits, dim=1)
                     distrib = torch.distributions.Categorical(posterior)

@@ -149,8 +149,6 @@ def sample_from_discretized_mix_logistic(y, log_scale_min=None):
 
 def to_one_hot(tensor, n, fill_with=1.0):
     # we perform one hot encore with respect to the last axis
-    one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
-    if tensor.is_cuda:
-        one_hot = one_hot.cuda()
+    one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_().type_as(tensor)
     one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
     return one_hot
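
`to_one_hot` scatters `fill_with` along a new trailing axis sized `n`. A worked example with toy sizes (float output, as in the pre-change code):

    import torch

    indices = torch.tensor([2, 0, 1])             # class ids, shape (3,)
    one_hot = torch.zeros(indices.size() + (4,))  # zeros of shape (3, 4)
    one_hot.scatter_(len(indices.size()), indices.unsqueeze(-1), 1.0)
    # tensor([[0., 0., 1., 0.],
    #         [1., 0., 0., 0.],
    #         [0., 1., 0., 0.]])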