From f7587fc1346987e2882419ded3dc8b82d12a3b39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 13 Jul 2022 10:47:12 +0200 Subject: [PATCH] Fix SSIM loss correction --- TTS/tts/layers/losses.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 5d501036..5130ac0b 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -133,11 +133,11 @@ class SSIMLoss(torch.nn.Module): if ssim_loss.item() > 1.0: print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0") - ssim_loss = torch.tensor([1.0], device=ssim_loss.device) + ssim_loss = torch.tensor(1.0, device=ssim_loss.device) if ssim_loss.item() < 0.0: print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0") - ssim_loss = torch.tensor([0.0], device=ssim_loss.device) + ssim_loss = torch.tensor(0.0, device=ssim_loss.device) return ssim_loss