mirror of https://github.com/coqui-ai/TTS.git
update glow-tts output shapes to match [B, T, C]
This commit is contained in:
parent
f568833d28
commit
506189bdee
|
@ -185,13 +185,13 @@ class GlowTTS(nn.Module):
|
|||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
|
||||
attn = attn.squeeze(1).permute(0, 2, 1)
|
||||
outputs = {
|
||||
"model_outputs": z,
|
||||
"model_outputs": z.transpose(1, 2),
|
||||
"logdet": logdet,
|
||||
"y_mean": y_mean,
|
||||
"y_log_scale": y_log_scale,
|
||||
"y_mean": y_mean.transpose(1, 2),
|
||||
"y_log_scale": y_log_scale.transpose(1, 2),
|
||||
"alignments": attn,
|
||||
"durations_log": o_dur_log,
|
||||
"total_durations_log": o_attn_dur,
|
||||
"durations_log": o_dur_log.transpose(1, 2),
|
||||
"total_durations_log": o_attn_dur.transpose(1, 2),
|
||||
}
|
||||
return outputs
|
||||
|
||||
|
@ -246,13 +246,13 @@ class GlowTTS(nn.Module):
|
|||
# reverse the decoder and predict using the aligned distribution
|
||||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||
outputs = {
|
||||
"model_outputs": y,
|
||||
"model_outputs": z.transpose(1, 2),
|
||||
"logdet": logdet,
|
||||
"y_mean": y_mean,
|
||||
"y_log_scale": y_log_scale,
|
||||
"y_mean": y_mean.transpose(1, 2),
|
||||
"y_log_scale": y_log_scale.transpose(1, 2),
|
||||
"alignments": attn,
|
||||
"durations_log": o_dur_log,
|
||||
"total_durations_log": o_attn_dur,
|
||||
"durations_log": o_dur_log.transpose(1, 2),
|
||||
"total_durations_log": o_attn_dur.transpose(1, 2),
|
||||
}
|
||||
return outputs
|
||||
|
||||
|
@ -285,7 +285,7 @@ class GlowTTS(nn.Module):
|
|||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||
|
||||
outputs = {}
|
||||
outputs["model_outputs"] = y
|
||||
outputs["model_outputs"] = y.transpose(1, 2)
|
||||
outputs["logdet"] = logdet
|
||||
return outputs
|
||||
|
||||
|
@ -317,13 +317,13 @@ class GlowTTS(nn.Module):
|
|||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||
attn = attn.squeeze(1).permute(0, 2, 1)
|
||||
outputs = {
|
||||
"model_outputs": y,
|
||||
"model_outputs": y.transpose(1, 2),
|
||||
"logdet": logdet,
|
||||
"y_mean": y_mean,
|
||||
"y_log_scale": y_log_scale,
|
||||
"y_mean": y_mean.transpose(1, 2),
|
||||
"y_log_scale": y_log_scale.transpose(1, 2),
|
||||
"alignments": attn,
|
||||
"durations_log": o_dur_log,
|
||||
"total_durations_log": o_attn_dur,
|
||||
"durations_log": o_dur_log.transpose(1, 2),
|
||||
"total_durations_log": o_attn_dur.transpose(1, 2),
|
||||
}
|
||||
return outputs
|
||||
|
||||
|
|
Loading…
Reference in New Issue