mirror of https://github.com/coqui-ai/TTS.git
Fix SSIM loss correction
This commit is contained in:
parent
bc1f93c299
commit
f7587fc134
|
@ -133,11 +133,11 @@ class SSIMLoss(torch.nn.Module):
|
||||||
|
|
||||||
if ssim_loss.item() > 1.0:
|
if ssim_loss.item() > 1.0:
|
||||||
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0")
|
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 1.0")
|
||||||
ssim_loss = torch.tensor([1.0], device=ssim_loss.device)
|
ssim_loss = torch.tensor(1.0, device=ssim_loss.device)
|
||||||
|
|
||||||
if ssim_loss.item() < 0.0:
|
if ssim_loss.item() < 0.0:
|
||||||
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
|
print(f" > SSIM loss is out-of-range {ssim_loss.item()}, setting it 0.0")
|
||||||
ssim_loss = torch.tensor([0.0], device=ssim_loss.device)
|
ssim_loss = torch.tensor(0.0, device=ssim_loss.device)
|
||||||
|
|
||||||
return ssim_loss
|
return ssim_loss
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue