mirror of https://github.com/coqui-ai/TTS.git
fix glow-tts inference and forward functions for handling `cond_input`
and refactor its test
This commit is contained in:
parent
f840268181
commit
6c495c6a6e
|
@ -154,10 +154,10 @@ class GlowTTS(nn.Module):
|
||||||
y_lengths: B
|
y_lengths: B
|
||||||
g: [B, C] or B
|
g: [B, C] or B
|
||||||
"""
|
"""
|
||||||
y_max_length = y.size(2)
|
|
||||||
y = y.transpose(1, 2)
|
y = y.transpose(1, 2)
|
||||||
|
y_max_length = y.size(2)
|
||||||
# norm speaker embeddings
|
# norm speaker embeddings
|
||||||
g = cond_input["x_vectors"]
|
g = cond_input["x_vectors"] if cond_input is not None and "x_vectors" in cond_input else None
|
||||||
if g is not None:
|
if g is not None:
|
||||||
if self.speaker_embedding_dim:
|
if self.speaker_embedding_dim:
|
||||||
g = F.normalize(g).unsqueeze(-1)
|
g = F.normalize(g).unsqueeze(-1)
|
||||||
|
@ -196,19 +196,23 @@ class GlowTTS(nn.Module):
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference_with_MAS(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
|
def inference_with_MAS(
|
||||||
|
self, x, x_lengths, y=None, y_lengths=None, cond_input={"x_vectors": None}
|
||||||
|
): # pylint: disable=dangerous-default-value
|
||||||
"""
|
"""
|
||||||
It's similar to the teacher forcing in Tacotron.
|
It's similar to the teacher forcing in Tacotron.
|
||||||
It was proposed in: https://arxiv.org/abs/2104.05557
|
It was proposed in: https://arxiv.org/abs/2104.05557
|
||||||
Shapes:
|
Shapes:
|
||||||
x: [B, T]
|
x: [B, T]
|
||||||
x_lenghts: B
|
x_lenghts: B
|
||||||
y: [B, C, T]
|
y: [B, T, C]
|
||||||
y_lengths: B
|
y_lengths: B
|
||||||
g: [B, C] or B
|
g: [B, C] or B
|
||||||
"""
|
"""
|
||||||
|
y = y.transpose(1, 2)
|
||||||
y_max_length = y.size(2)
|
y_max_length = y.size(2)
|
||||||
# norm speaker embeddings
|
# norm speaker embeddings
|
||||||
|
g = cond_input["x_vectors"] if cond_input is not None and "x_vectors" in cond_input else None
|
||||||
if g is not None:
|
if g is not None:
|
||||||
if self.external_speaker_embedding_dim:
|
if self.external_speaker_embedding_dim:
|
||||||
g = F.normalize(g).unsqueeze(-1)
|
g = F.normalize(g).unsqueeze(-1)
|
||||||
|
@ -253,14 +257,18 @@ class GlowTTS(nn.Module):
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def decoder_inference(self, y, y_lengths=None, g=None):
|
def decoder_inference(
|
||||||
|
self, y, y_lengths=None, cond_input={"x_vectors": None}
|
||||||
|
): # pylint: disable=dangerous-default-value
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
y: [B, C, T]
|
y: [B, T, C]
|
||||||
y_lengths: B
|
y_lengths: B
|
||||||
g: [B, C] or B
|
g: [B, C] or B
|
||||||
"""
|
"""
|
||||||
|
y = y.transpose(1, 2)
|
||||||
y_max_length = y.size(2)
|
y_max_length = y.size(2)
|
||||||
|
g = cond_input["x_vectors"] if cond_input is not None and "x_vectors" in cond_input else None
|
||||||
# norm speaker embeddings
|
# norm speaker embeddings
|
||||||
if g is not None:
|
if g is not None:
|
||||||
if self.external_speaker_embedding_dim:
|
if self.external_speaker_embedding_dim:
|
||||||
|
@ -276,10 +284,14 @@ class GlowTTS(nn.Module):
|
||||||
# reverse decoder and predict
|
# reverse decoder and predict
|
||||||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||||
|
|
||||||
return y, logdet
|
outputs = {}
|
||||||
|
outputs["model_outputs"] = y
|
||||||
|
outputs["logdet"] = logdet
|
||||||
|
return outputs
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference(self, x, x_lengths, g=None):
|
def inference(self, x, x_lengths, cond_input={"x_vectors": None}): # pylint: disable=dangerous-default-value
|
||||||
|
g = cond_input["x_vectors"] if cond_input is not None and "x_vectors" in cond_input else None
|
||||||
if g is not None:
|
if g is not None:
|
||||||
if self.speaker_embedding_dim:
|
if self.speaker_embedding_dim:
|
||||||
g = F.normalize(g).unsqueeze(-1)
|
g = F.normalize(g).unsqueeze(-1)
|
||||||
|
|
|
@ -34,7 +34,7 @@ class GlowTTSTrainTest(unittest.TestCase):
|
||||||
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
|
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
|
||||||
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
|
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
|
||||||
input_lengths[-1] = 128
|
input_lengths[-1] = 128
|
||||||
mel_spec = torch.rand(8, c.audio["num_mels"], 30).to(device)
|
mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
|
||||||
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
|
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
|
||||||
speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
|
speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
|
||||||
|
|
||||||
|
@ -114,10 +114,17 @@ class GlowTTSTrainTest(unittest.TestCase):
|
||||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
|
outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, None)
|
||||||
input_dummy, input_lengths, mel_spec, mel_lengths, None
|
loss_dict = criterion(
|
||||||
|
outputs["model_outputs"],
|
||||||
|
outputs["y_mean"],
|
||||||
|
outputs["y_log_scale"],
|
||||||
|
outputs["logdet"],
|
||||||
|
mel_lengths,
|
||||||
|
outputs["durations_log"],
|
||||||
|
outputs["total_durations_log"],
|
||||||
|
input_lengths,
|
||||||
)
|
)
|
||||||
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, input_lengths)
|
|
||||||
loss = loss_dict["loss"]
|
loss = loss_dict["loss"]
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
@ -137,7 +144,7 @@ class GlowTTSInferenceTest(unittest.TestCase):
|
||||||
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
|
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
|
||||||
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
|
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
|
||||||
input_lengths[-1] = 128
|
input_lengths[-1] = 128
|
||||||
mel_spec = torch.rand(8, c.audio["num_mels"], 30).to(device)
|
mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
|
||||||
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
|
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
|
||||||
speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
|
speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
|
||||||
|
|
||||||
|
@ -175,12 +182,12 @@ class GlowTTSInferenceTest(unittest.TestCase):
|
||||||
print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
|
print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
|
||||||
|
|
||||||
# inference encoder and decoder with MAS
|
# inference encoder and decoder with MAS
|
||||||
y, *_ = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths, None)
|
y = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths)
|
||||||
|
|
||||||
y_dec, _ = model.decoder_inference(mel_spec, mel_lengths)
|
y2 = model.decoder_inference(mel_spec, mel_lengths)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
y_dec.shape == y.shape
|
y2["model_outputs"].shape == y["model_outputs"].shape
|
||||||
), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
|
), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
|
||||||
y.shape, y_dec.shape
|
y["model_outputs"].shape, y2["model_outputs"].shape
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue