mirror of https://github.com/coqui-ai/TTS.git
Fix resnet speaker encoder
This commit is contained in:
parent
348b5c96a2
commit
36cef5966b
|
@ -198,7 +198,7 @@ class ResNetSpeakerEncoder(nn.Module):
|
||||||
l2_norm (bool): Whether to L2-normalize the outputs.
|
l2_norm (bool): Whether to L2-normalize the outputs.
|
||||||
|
|
||||||
Shapes:
|
Shapes:
|
||||||
- x: :math:`(N, 1, T_{in})`
|
- x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
|
||||||
"""
|
"""
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
|
@ -206,8 +206,6 @@ class ResNetSpeakerEncoder(nn.Module):
|
||||||
# if use_torch_spec is set, compute the spectrogram with torch; otherwise use the mel spec computed by the AP
|
# if use_torch_spec is set, compute the spectrogram with torch; otherwise use the mel spec computed by the AP
|
||||||
if self.use_torch_spec:
|
if self.use_torch_spec:
|
||||||
x = self.torch_spec(x)
|
x = self.torch_spec(x)
|
||||||
else:
|
|
||||||
x = x.transpose(1, 2)
|
|
||||||
|
|
||||||
if self.log_input:
|
if self.log_input:
|
||||||
x = (x + 1e-6).log()
|
x = (x + 1e-6).log()
|
||||||
|
|
|
@ -417,8 +417,7 @@ class Vits(BaseTTS):
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(self.speaker_manager.speaker_encoder, "audio_config")
|
hasattr(self.speaker_manager.speaker_encoder, "audio_config")
|
||||||
and self.config.audio["sample_rate"]
|
and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
|
||||||
!= self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
|
|
||||||
):
|
):
|
||||||
# TODO: change this with torchaudio Resample
|
# TODO: change this with torchaudio Resample
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|
|
@ -19,7 +19,7 @@ use_cuda = torch.cuda.is_available()
|
||||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
||||||
#pylint: disable=no-self-use
|
# pylint: disable=no-self-use
|
||||||
class TestVits(unittest.TestCase):
|
class TestVits(unittest.TestCase):
|
||||||
def test_init_multispeaker(self):
|
def test_init_multispeaker(self):
|
||||||
num_speakers = 10
|
num_speakers = 10
|
||||||
|
|
Loading…
Reference in New Issue