From b22b7620c3dd515efc4daf2f0787b8bbe86b00d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 31 May 2021 15:42:07 +0200 Subject: [PATCH] update glow-tts output shapes to match [B, T, C] --- TTS/tts/models/glow_tts.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 2c944008..af52ba1c 100755 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -185,13 +185,13 @@ class GlowTTS(nn.Module): y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) attn = attn.squeeze(1).permute(0, 2, 1) outputs = { - "model_outputs": z, + "model_outputs": z.transpose(1, 2), "logdet": logdet, - "y_mean": y_mean, - "y_log_scale": y_log_scale, + "y_mean": y_mean.transpose(1, 2), + "y_log_scale": y_log_scale.transpose(1, 2), "alignments": attn, - "durations_log": o_dur_log, - "total_durations_log": o_attn_dur, + "durations_log": o_dur_log.transpose(1, 2), + "total_durations_log": o_attn_dur.transpose(1, 2), } return outputs @@ -246,13 +246,13 @@ class GlowTTS(nn.Module): # reverse the decoder and predict using the aligned distribution y, logdet = self.decoder(z, y_mask, g=g, reverse=True) outputs = { - "model_outputs": y, + "model_outputs": z.transpose(1, 2), "logdet": logdet, - "y_mean": y_mean, - "y_log_scale": y_log_scale, + "y_mean": y_mean.transpose(1, 2), + "y_log_scale": y_log_scale.transpose(1, 2), "alignments": attn, - "durations_log": o_dur_log, - "total_durations_log": o_attn_dur, + "durations_log": o_dur_log.transpose(1, 2), + "total_durations_log": o_attn_dur.transpose(1, 2), } return outputs @@ -285,7 +285,7 @@ class GlowTTS(nn.Module): y, logdet = self.decoder(z, y_mask, g=g, reverse=True) outputs = {} - outputs["model_outputs"] = y + outputs["model_outputs"] = y.transpose(1, 2) outputs["logdet"] = logdet return outputs @@ -317,13 +317,13 @@ class GlowTTS(nn.Module): y, logdet = self.decoder(z, y_mask, g=g, reverse=True) attn = attn.squeeze(1).permute(0, 2, 1) outputs = { - "model_outputs": y, + "model_outputs": y.transpose(1, 2), "logdet": logdet, - "y_mean": y_mean, - "y_log_scale": y_log_scale, + "y_mean": y_mean.transpose(1, 2), + "y_log_scale": y_log_scale.transpose(1, 2), "alignments": attn, - "durations_log": o_dur_log, - "total_durations_log": o_attn_dur, + "durations_log": o_dur_log.transpose(1, 2), + "total_durations_log": o_attn_dur.transpose(1, 2), } return outputs