diff --git a/TTS/encoder/models/resnet.py b/TTS/encoder/models/resnet.py index e75ab6c4..5eafcd60 100644 --- a/TTS/encoder/models/resnet.py +++ b/TTS/encoder/models/resnet.py @@ -161,16 +161,14 @@ class ResNetSpeakerEncoder(BaseEncoder): Shapes: - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` """ - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - x.squeeze_(1) - # if you torch spec compute it otherwise use the mel spec computed by the AP - if self.use_torch_spec: - x = self.torch_spec(x) + x.squeeze_(1) + # if you torch spec compute it otherwise use the mel spec computed by the AP + if self.use_torch_spec: + x = self.torch_spec(x) - if self.log_input: - x = (x + 1e-6).log() - x = self.instancenorm(x).unsqueeze(1) + if self.log_input: + x = (x + 1e-6).log() + x = self.instancenorm(x).unsqueeze(1) x = self.conv1(x) x = self.relu(x)