From 36cef5966b50bf5596eb78f45c48025b8114411c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 30 Dec 2021 14:41:43 +0000 Subject: [PATCH] Fix resnet speaker encoder --- TTS/speaker_encoder/models/resnet.py | 4 +--- TTS/tts/models/vits.py | 3 +-- tests/tts_tests/test_vits.py | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/TTS/speaker_encoder/models/resnet.py b/TTS/speaker_encoder/models/resnet.py index f1f13df1..643449c8 100644 --- a/TTS/speaker_encoder/models/resnet.py +++ b/TTS/speaker_encoder/models/resnet.py @@ -198,7 +198,7 @@ class ResNetSpeakerEncoder(nn.Module): l2_norm (bool): Whether to L2-normalize the outputs. Shapes: - - x: :math:`(N, 1, T_{in})` + - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` """ with torch.no_grad(): with torch.cuda.amp.autocast(enabled=False): @@ -206,8 +206,6 @@ class ResNetSpeakerEncoder(nn.Module): # if you torch spec compute it otherwise use the mel spec computed by the AP if self.use_torch_spec: x = self.torch_spec(x) - else: - x = x.transpose(1, 2) if self.log_input: x = (x + 1e-6).log() diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index e4e64240..8b09fdf9 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -417,8 +417,7 @@ class Vits(BaseTTS): if ( hasattr(self.speaker_manager.speaker_encoder, "audio_config") - and self.config.audio["sample_rate"] - != self.speaker_manager.speaker_encoder.audio_config["sample_rate"] + and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"] ): # TODO: change this with torchaudio Resample raise RuntimeError( diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index de075a5c..4274d947 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -19,7 +19,7 @@ use_cuda = torch.cuda.is_available() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -#pylint: disable=no-self-use +# pylint: disable=no-self-use class TestVits(unittest.TestCase): def test_init_multispeaker(self): num_speakers = 10