Fix resnet speaker encoder

This commit is contained in:
Eren Gölge 2021-12-30 14:41:43 +00:00
parent 348b5c96a2
commit 36cef5966b
3 changed files with 3 additions and 6 deletions

View File

@ -198,7 +198,7 @@ class ResNetSpeakerEncoder(nn.Module):
l2_norm (bool): Whether to L2-normalize the outputs. l2_norm (bool): Whether to L2-normalize the outputs.
Shapes: Shapes:
- x: :math:`(N, 1, T_{in})` - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
""" """
with torch.no_grad(): with torch.no_grad():
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
@ -206,8 +206,6 @@ class ResNetSpeakerEncoder(nn.Module):
# if use_torch_spec is set, compute the spec with torch; otherwise use the mel spec computed by the AP # if use_torch_spec is set, compute the spec with torch; otherwise use the mel spec computed by the AP
if self.use_torch_spec: if self.use_torch_spec:
x = self.torch_spec(x) x = self.torch_spec(x)
else:
x = x.transpose(1, 2)
if self.log_input: if self.log_input:
x = (x + 1e-6).log() x = (x + 1e-6).log()

View File

@ -417,8 +417,7 @@ class Vits(BaseTTS):
if ( if (
hasattr(self.speaker_manager.speaker_encoder, "audio_config") hasattr(self.speaker_manager.speaker_encoder, "audio_config")
and self.config.audio["sample_rate"] and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
!= self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
): ):
# TODO: change this with torchaudio Resample # TODO: change this with torchaudio Resample
raise RuntimeError( raise RuntimeError(

View File

@ -19,7 +19,7 @@ use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#pylint: disable=no-self-use # pylint: disable=no-self-use
class TestVits(unittest.TestCase): class TestVits(unittest.TestCase):
def test_init_multispeaker(self): def test_init_multispeaker(self):
num_speakers = 10 num_speakers = 10