mirror of https://github.com/coqui-ai/TTS.git
Fix SSIM loss correction
This commit is contained in:
parent
bc1f93c299
commit
f7587fc134
|
@ -133,11 +133,11 @@ class SSIMLoss(torch.nn.Module):
|
|||
|
||||
if ssim_loss.item() > 1.0:
|
||||
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0")
|
||||
ssim_loss = torch.tensor([1.0], device=ssim_loss.device)
|
||||
ssim_loss = torch.tensor(1.0, device=ssim_loss.device)
|
||||
|
||||
if ssim_loss.item() < 0.0:
|
||||
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
|
||||
ssim_loss = torch.tensor([0.0], device=ssim_loss.device)
|
||||
ssim_loss = torch.tensor(0.0, device=ssim_loss.device)
|
||||
|
||||
return ssim_loss
|
||||
|
||||
|
|
Loading…
Reference in New Issue