diff --git a/train.py b/train.py index 46e2c43f..328e81e8 100644 --- a/train.py +++ b/train.py @@ -372,10 +372,10 @@ def evaluate(model, criterion, ap, global_step, epoch): # aggregate losses from processes if num_gpus > 1: - postnet_loss = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus) - decoder_loss = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus) + loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus) + loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus) if c.stopnet: - stopnet_loss = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus) + loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus) keep_avg.update_values({ 'avg_postnet_loss':