mirror of https://github.com/coqui-ai/TTS.git
Use `torch.linalg.qr` for pytorch > `v1.9.0`
This commit is contained in:
parent
0a1962b583
commit
d42d1c02ea
|
@ -1,3 +1,5 @@
|
||||||
|
from distutils.version import LooseVersion
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
@ -81,7 +83,11 @@ class InvConvNear(nn.Module):
|
||||||
self.no_jacobian = no_jacobian
|
self.no_jacobian = no_jacobian
|
||||||
self.weight_inv = None
|
self.weight_inv = None
|
||||||
|
|
||||||
w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
|
if LooseVersion(torch.__version__) < LooseVersion("1.9"):
|
||||||
|
w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
|
||||||
|
else:
|
||||||
|
w_init = torch.linalg.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_(), "complete")[0]
|
||||||
|
|
||||||
if torch.det(w_init) < 0:
|
if torch.det(w_init) < 0:
|
||||||
w_init[:, 0] = -1 * w_init[:, 0]
|
w_init[:, 0] = -1 * w_init[:, 0]
|
||||||
self.weight = nn.Parameter(w_init)
|
self.weight = nn.Parameter(w_init)
|
||||||
|
|
Loading…
Reference in New Issue