mirror of https://github.com/coqui-ai/TTS.git
update glow-tts output shapes to match [B, T, C]
This commit is contained in:
parent
8381379938
commit
b22b7620c3
|
@ -185,13 +185,13 @@ class GlowTTS(nn.Module):
|
||||||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
|
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
|
||||||
attn = attn.squeeze(1).permute(0, 2, 1)
|
attn = attn.squeeze(1).permute(0, 2, 1)
|
||||||
outputs = {
|
outputs = {
|
||||||
"model_outputs": z,
|
"model_outputs": z.transpose(1, 2),
|
||||||
"logdet": logdet,
|
"logdet": logdet,
|
||||||
"y_mean": y_mean,
|
"y_mean": y_mean.transpose(1, 2),
|
||||||
"y_log_scale": y_log_scale,
|
"y_log_scale": y_log_scale.transpose(1, 2),
|
||||||
"alignments": attn,
|
"alignments": attn,
|
||||||
"durations_log": o_dur_log,
|
"durations_log": o_dur_log.transpose(1, 2),
|
||||||
"total_durations_log": o_attn_dur,
|
"total_durations_log": o_attn_dur.transpose(1, 2),
|
||||||
}
|
}
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@ -246,13 +246,13 @@ class GlowTTS(nn.Module):
|
||||||
# reverse the decoder and predict using the aligned distribution
|
# reverse the decoder and predict using the aligned distribution
|
||||||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||||
outputs = {
|
outputs = {
|
||||||
"model_outputs": y,
|
"model_outputs": z.transpose(1, 2),
|
||||||
"logdet": logdet,
|
"logdet": logdet,
|
||||||
"y_mean": y_mean,
|
"y_mean": y_mean.transpose(1, 2),
|
||||||
"y_log_scale": y_log_scale,
|
"y_log_scale": y_log_scale.transpose(1, 2),
|
||||||
"alignments": attn,
|
"alignments": attn,
|
||||||
"durations_log": o_dur_log,
|
"durations_log": o_dur_log.transpose(1, 2),
|
||||||
"total_durations_log": o_attn_dur,
|
"total_durations_log": o_attn_dur.transpose(1, 2),
|
||||||
}
|
}
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@ -285,7 +285,7 @@ class GlowTTS(nn.Module):
|
||||||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||||
|
|
||||||
outputs = {}
|
outputs = {}
|
||||||
outputs["model_outputs"] = y
|
outputs["model_outputs"] = y.transpose(1, 2)
|
||||||
outputs["logdet"] = logdet
|
outputs["logdet"] = logdet
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@ -317,13 +317,13 @@ class GlowTTS(nn.Module):
|
||||||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||||
attn = attn.squeeze(1).permute(0, 2, 1)
|
attn = attn.squeeze(1).permute(0, 2, 1)
|
||||||
outputs = {
|
outputs = {
|
||||||
"model_outputs": y,
|
"model_outputs": y.transpose(1, 2),
|
||||||
"logdet": logdet,
|
"logdet": logdet,
|
||||||
"y_mean": y_mean,
|
"y_mean": y_mean.transpose(1, 2),
|
||||||
"y_log_scale": y_log_scale,
|
"y_log_scale": y_log_scale.transpose(1, 2),
|
||||||
"alignments": attn,
|
"alignments": attn,
|
||||||
"durations_log": o_dur_log,
|
"durations_log": o_dur_log.transpose(1, 2),
|
||||||
"total_durations_log": o_attn_dur,
|
"total_durations_log": o_attn_dur.transpose(1, 2),
|
||||||
}
|
}
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue