glow-tts fix for saving inverse weight

This commit is contained in:
root 2021-01-20 02:09:42 +00:00
parent 3d30dae8f3
commit 5c87753e88
1 changed files with 2 additions and 1 deletions

View File

@ -128,8 +128,9 @@ class InvConvNear(nn.Module):
return z, logdet
def store_inverse(self):
self.weight_inv = torch.inverse(
weight_inv = torch.inverse(
self.weight.float()).to(dtype=self.weight.dtype)
self.weight_inv = nn.Parameter(weight_inv, requires_grad=False)
class CouplingBlock(nn.Module):