Fix SSIM loss correction

This commit is contained in:
Eren Gölge 2022-07-13 10:47:12 +02:00 committed by GitHub
parent bc1f93c299
commit f7587fc134
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -133,11 +133,11 @@ class SSIMLoss(torch.nn.Module):
if ssim_loss.item() > 1.0:
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0")
ssim_loss = torch.tensor([1.0], device=ssim_loss.device)
ssim_loss = torch.tensor(1.0, device=ssim_loss.device)
if ssim_loss.item() < 0.0:
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
ssim_loss = torch.tensor([0.0], device=ssim_loss.device)
ssim_loss = torch.tensor(0.0, device=ssim_loss.device)
return ssim_loss