diff --git a/train.py b/train.py index 383fd64d..693cef8a 100644 --- a/train.py +++ b/train.py @@ -187,7 +187,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, # Diagnostic visualizations const_spec = postnet_output[0].data.cpu().numpy() - gt_spec = mel_input[0].data.cpu().numpy() + gt_spec = linear_input[0].data.cpu().numpy() if c.model == "Tacotron" else mel_input[0].data.cpu().numpy() align_img = alignments[0].data.cpu().numpy() figures = { @@ -315,7 +315,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch): # Diagnostic visualizations idx = np.random.randint(mel_input.shape[0]) const_spec = postnet_output[idx].data.cpu().numpy() - gt_spec = mel_input[idx].data.cpu().numpy() + gt_spec = linear_input[idx].data.cpu().numpy() if c.model == "Tacotron" else mel_input[idx].data.cpu().numpy() align_img = alignments[idx].data.cpu().numpy() eval_figures = {