diff --git a/vocoder/train.py b/vocoder/train.py index 326bd90e..41b3c1ec 100644 --- a/vocoder/train.py +++ b/vocoder/train.py @@ -217,7 +217,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, scheduler_D.step() for key, value in loss_D_dict.items(): - loss_dict[key] = value.item() + if isinstance(value, (int, float)): + loss_dict[key] = value + else: + loss_dict[key] = value.item() step_time = time.time() - start_time epoch_time += step_time @@ -355,7 +358,10 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) loss_dict = dict() for key, value in loss_G_dict.items(): - loss_dict[key] = value.item() + if isinstance(value, (int, float)): + loss_dict[key] = value + else: + loss_dict[key] = value.item() ############################## # DISCRIMINATOR @@ -393,7 +399,10 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) loss_D_dict = criterion_D(scores_fake, scores_real) for key, value in loss_D_dict.items(): - loss_dict[key] = value.item() + if isinstance(value, (int, float)): + loss_dict[key] = value + else: + loss_dict[key] = value.item() step_time = time.time() - start_time