mirror of https://github.com/coqui-ai/TTS.git
do not call .item() if returned loss is not torch type
This commit is contained in:
parent
8f72ad900a
commit
3c20afa1c9
|
@ -217,7 +217,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
scheduler_D.step()
|
||||
|
||||
for key, value in loss_D_dict.items():
|
||||
loss_dict[key] = value.item()
|
||||
if isinstance(value, (int, float)):
|
||||
loss_dict[key] = value
|
||||
else:
|
||||
loss_dict[key] = value.item()
|
||||
|
||||
step_time = time.time() - start_time
|
||||
epoch_time += step_time
|
||||
|
@ -355,7 +358,10 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
|
|||
|
||||
loss_dict = dict()
|
||||
for key, value in loss_G_dict.items():
|
||||
loss_dict[key] = value.item()
|
||||
if isinstance(value, (int, float)):
|
||||
loss_dict[key] = value
|
||||
else:
|
||||
loss_dict[key] = value.item()
|
||||
|
||||
##############################
|
||||
# DISCRIMINATOR
|
||||
|
@ -393,7 +399,10 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
|
|||
loss_D_dict = criterion_D(scores_fake, scores_real)
|
||||
|
||||
for key, value in loss_D_dict.items():
|
||||
loss_dict[key] = value.item()
|
||||
if isinstance(value, (int, float)):
|
||||
loss_dict[key] = value
|
||||
else:
|
||||
loss_dict[key] = value.item()
|
||||
|
||||
|
||||
step_time = time.time() - start_time
|
||||
|
|
Loading…
Reference in New Issue