diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 2c944008..af52ba1c 100755 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -185,13 +185,13 @@ class GlowTTS(nn.Module): y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) attn = attn.squeeze(1).permute(0, 2, 1) outputs = { - "model_outputs": z, + "model_outputs": z.transpose(1, 2), "logdet": logdet, - "y_mean": y_mean, - "y_log_scale": y_log_scale, + "y_mean": y_mean.transpose(1, 2), + "y_log_scale": y_log_scale.transpose(1, 2), "alignments": attn, - "durations_log": o_dur_log, - "total_durations_log": o_attn_dur, + "durations_log": o_dur_log.transpose(1, 2), + "total_durations_log": o_attn_dur.transpose(1, 2), } return outputs @@ -246,13 +246,13 @@ class GlowTTS(nn.Module): # reverse the decoder and predict using the aligned distribution y, logdet = self.decoder(z, y_mask, g=g, reverse=True) outputs = { - "model_outputs": y, + "model_outputs": z.transpose(1, 2), "logdet": logdet, - "y_mean": y_mean, - "y_log_scale": y_log_scale, + "y_mean": y_mean.transpose(1, 2), + "y_log_scale": y_log_scale.transpose(1, 2), "alignments": attn, - "durations_log": o_dur_log, - "total_durations_log": o_attn_dur, + "durations_log": o_dur_log.transpose(1, 2), + "total_durations_log": o_attn_dur.transpose(1, 2), } return outputs @@ -285,7 +285,7 @@ class GlowTTS(nn.Module): y, logdet = self.decoder(z, y_mask, g=g, reverse=True) outputs = {} - outputs["model_outputs"] = y + outputs["model_outputs"] = y.transpose(1, 2) outputs["logdet"] = logdet return outputs @@ -317,13 +317,13 @@ class GlowTTS(nn.Module): y, logdet = self.decoder(z, y_mask, g=g, reverse=True) attn = attn.squeeze(1).permute(0, 2, 1) outputs = { - "model_outputs": y, + "model_outputs": y.transpose(1, 2), "logdet": logdet, - "y_mean": y_mean, - "y_log_scale": y_log_scale, + "y_mean": y_mean.transpose(1, 2), + "y_log_scale": y_log_scale.transpose(1, 2), "alignments": attn, - "durations_log": o_dur_log, - "total_durations_log": o_attn_dur, + "durations_log": o_dur_log.transpose(1, 2), + "total_durations_log": o_attn_dur.transpose(1, 2), } return outputs