diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py index f6385747..c8ad410d 100644 --- a/TTS/tts/layers/glow_tts/glow.py +++ b/TTS/tts/layers/glow_tts/glow.py @@ -128,8 +128,9 @@ class InvConvNear(nn.Module): return z, logdet def store_inverse(self): - self.weight_inv = torch.inverse( + weight_inv = torch.inverse( self.weight.float()).to(dtype=self.weight.dtype) + self.weight_inv = nn.Parameter(weight_inv, requires_grad=False) class CouplingBlock(nn.Module):