update glow-tts output shapes to match [B, T, C]

This commit is contained in:
Eren Gölge 2021-05-31 15:42:07 +02:00
parent f568833d28
commit 506189bdee
1 changed files with 16 additions and 16 deletions

View File

@ -185,13 +185,13 @@ class GlowTTS(nn.Module):
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
attn = attn.squeeze(1).permute(0, 2, 1)
outputs = {
"model_outputs": z,
"model_outputs": z.transpose(1, 2),
"logdet": logdet,
"y_mean": y_mean,
"y_log_scale": y_log_scale,
"y_mean": y_mean.transpose(1, 2),
"y_log_scale": y_log_scale.transpose(1, 2),
"alignments": attn,
"durations_log": o_dur_log,
"total_durations_log": o_attn_dur,
"durations_log": o_dur_log.transpose(1, 2),
"total_durations_log": o_attn_dur.transpose(1, 2),
}
return outputs
@ -246,13 +246,13 @@ class GlowTTS(nn.Module):
# reverse the decoder and predict using the aligned distribution
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
outputs = {
"model_outputs": y,
"model_outputs": z.transpose(1, 2),
"logdet": logdet,
"y_mean": y_mean,
"y_log_scale": y_log_scale,
"y_mean": y_mean.transpose(1, 2),
"y_log_scale": y_log_scale.transpose(1, 2),
"alignments": attn,
"durations_log": o_dur_log,
"total_durations_log": o_attn_dur,
"durations_log": o_dur_log.transpose(1, 2),
"total_durations_log": o_attn_dur.transpose(1, 2),
}
return outputs
@ -285,7 +285,7 @@ class GlowTTS(nn.Module):
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
outputs = {}
outputs["model_outputs"] = y
outputs["model_outputs"] = y.transpose(1, 2)
outputs["logdet"] = logdet
return outputs
@ -317,13 +317,13 @@ class GlowTTS(nn.Module):
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
attn = attn.squeeze(1).permute(0, 2, 1)
outputs = {
"model_outputs": y,
"model_outputs": y.transpose(1, 2),
"logdet": logdet,
"y_mean": y_mean,
"y_log_scale": y_log_scale,
"y_mean": y_mean.transpose(1, 2),
"y_log_scale": y_log_scale.transpose(1, 2),
"alignments": attn,
"durations_log": o_dur_log,
"total_durations_log": o_attn_dur,
"durations_log": o_dur_log.transpose(1, 2),
"total_durations_log": o_attn_dur.transpose(1, 2),
}
return outputs