diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py index 067a166f..99b8bba5 100644 --- a/TTS/bin/train_vocoder_gan.py +++ b/TTS/bin/train_vocoder_gan.py @@ -306,6 +306,7 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, # TODO: plot model stats # if c.tb_model_param_stats: # tb_logger.tb_model_weights(model, global_step) + torch.cuda.empty_cache() return keep_avg.avg_values, global_step @@ -433,9 +434,6 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) if c.print_eval: c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) - torch.cuda.empty_cache() - - if args.rank == 0: # compute spectrograms figures = plot_results(y_hat, y_G, ap, global_step, 'eval') @@ -450,7 +448,7 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) # synthesize a full voice data_loader.return_segments = False - + torch.cuda.empty_cache() return keep_avg.avg_values