mirror of https://github.com/coqui-ai/TTS.git
Fix speaker encoder test
parent 497332bd46
commit 348b5c96a2
@@ -127,7 +127,7 @@ class LSTMSpeakerEncoder(nn.Module):
 
     @torch.no_grad()
     def inference(self, x, l2_norm=True):
-        d = self.layers.forward(x, l2_norm=l2_norm)
+        d = self.forward(x, l2_norm=l2_norm)
         return d
 
     def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
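The hunk above routes `inference` through the model's own `forward` instead of `self.layers.forward`: the raw layer stack has no `l2_norm` argument, while the model-level `forward` applies the optional L2 normalization. Below is a minimal, self-contained sketch of that delegation pattern with a toy encoder; the layer stack and dimensions are placeholders, not the actual LSTMSpeakerEncoder.

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    def __init__(self, input_dim=80, proj_dim=256):
        super().__init__()
        # stand-in layer stack; the real LSTMSpeakerEncoder uses LSTMs + a projection
        self.layers = nn.Sequential(nn.Linear(input_dim, proj_dim), nn.ReLU())

    def forward(self, x, l2_norm=True):
        d = self.layers(x)  # the raw stack knows nothing about l2_norm
        if l2_norm:
            d = torch.nn.functional.normalize(d, p=2, dim=-1)
        return d

    @torch.no_grad()
    def inference(self, x, l2_norm=True):
        # delegate to forward() so normalization is applied consistently,
        # instead of calling self.layers.forward(x, l2_norm=...), which would fail
        return self.forward(x, l2_norm=l2_norm)


emb = ToyEncoder().inference(torch.rand(4, 80))
print(emb.shape)  # torch.Size([4, 256])
```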
@@ -398,10 +398,10 @@ class Vits(BaseTTS):
         self.num_speakers = self.speaker_manager.num_speakers
 
         if self.args.use_speaker_embedding:
-            self._init_speaker_embedding(config)
+            self._init_speaker_embedding()
 
         if self.args.use_d_vector_file:
-            self._init_d_vector(config)
+            self._init_d_vector()
 
         # TODO: make this a function
         if self.args.use_speaker_encoder_as_loss:
@@ -436,14 +436,14 @@ class Vits(BaseTTS):
             self.audio_transform = None
         """
 
-    def _init_speaker_embedding(self, config):
+    def _init_speaker_embedding(self):
         # pylint: disable=attribute-defined-outside-init
         if self.num_speakers > 0:
             print(" > initialization of speaker-embedding layers.")
             self.embedded_speaker_dim = self.args.speaker_embedding_channels
             self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
 
-    def _init_d_vector(self, config):
+    def _init_d_vector(self):
         # pylint: disable=attribute-defined-outside-init
         if hasattr(self, "emb_g"):
             raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
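The two Vits hunks above drop the unused `config` parameter from `_init_speaker_embedding` and `_init_d_vector`; everything they need is already on the instance (`self.num_speakers`, `self.args.speaker_embedding_channels`). A rough sketch of that pattern with a stand-in args object follows; the class and field names are placeholders, not the real Vits or coqpit config.

```python
# Initializer methods that read from self.args rather than a passed-in config.
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class Args:
    speaker_embedding_channels: int = 256  # placeholder default


class ToyMultiSpeakerModel(nn.Module):
    def __init__(self, num_speakers, args=None):
        super().__init__()
        self.args = args or Args()
        self.num_speakers = num_speakers
        self._init_speaker_embedding()  # no config argument needed

    def _init_speaker_embedding(self):
        # channel count comes from self.args, mirroring the signature change above
        if self.num_speakers > 0:
            self.embedded_speaker_dim = self.args.speaker_embedding_channels
            self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)


print(ToyMultiSpeakerModel(num_speakers=4).emb_g)  # Embedding(4, 256)
```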
@@ -270,7 +270,7 @@ class SpeakerManager:
         waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
         if not self.speaker_encoder_config.model_params.get("use_torch_spec", False):
             m_input = self.speaker_encoder_ap.melspectrogram(waveform)
-            m_input = torch.from_numpy(m_input.T)
+            m_input = torch.from_numpy(m_input)
         else:
             m_input = torch.from_numpy(waveform)
 
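With this change, `SpeakerManager` keeps the mel spectrogram in the order returned by `melspectrogram()` instead of transposing it, which matches the `B x D x T` layout the tests below switch to. A small shape-only sketch, assuming the mel comes back as `(n_mels, n_frames)` (that ordering is implied by the removed `.T`, not stated elsewhere on this page); the sizes are arbitrary examples.

```python
# Shape sketch only: a mel kept as D x T (n_mels x frames) is batched to
# 1 x D x T without any transpose, matching the encoder's B x D x T input.
import numpy as np
import torch

n_mels, n_frames = 80, 240
mel = np.random.rand(n_mels, n_frames).astype(np.float32)  # D x T

m_input = torch.from_numpy(mel)   # no .T anymore
batch = m_input.unsqueeze(0)      # 1 x D x T
print(batch.shape)                # torch.Size([1, 80, 240])
```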
@@ -13,7 +13,7 @@ file_path = get_tests_input_path()
 class LSTMSpeakerEncoderTests(unittest.TestCase):
     # pylint: disable=R0201
     def test_in_out(self):
-        dummy_input = T.rand(4, 20, 80)  # B x T x D
+        dummy_input = T.rand(4, 80, 20)  # B x D x T
         dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
         model = LSTMSpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3)
         # computing d vectors
@@ -34,7 +34,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
         assert output.type() == "torch.FloatTensor"
         assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
         # compute d for a given batch
-        dummy_input = T.rand(1, 240, 80)  # B x T x D
+        dummy_input = T.rand(1, 80, 240)  # B x D x T
         output = model.compute_embedding(dummy_input, num_frames=160, num_eval=5)
         assert output.shape[0] == 1
         assert output.shape[1] == 256
@@ -44,7 +44,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
 class ResNetSpeakerEncoderTests(unittest.TestCase):
     # pylint: disable=R0201
     def test_in_out(self):
-        dummy_input = T.rand(4, 20, 80)  # B x T x D
+        dummy_input = T.rand(4, 80, 20)  # B x D x T
         dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
         model = ResNetSpeakerEncoder(input_dim=80, proj_dim=256)
         # computing d vectors
@@ -61,7 +61,7 @@ class ResNetSpeakerEncoderTests(unittest.TestCase):
         assert output.type() == "torch.FloatTensor"
         assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
         # compute d for a given batch
-        dummy_input = T.rand(1, 240, 80)  # B x T x D
+        dummy_input = T.rand(1, 80, 240)  # B x D x T
         output = model.compute_embedding(dummy_input, num_frames=160, num_eval=10)
         assert output.shape[0] == 1
         assert output.shape[1] == 256
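The four encoder-test hunks above flip the dummy inputs from `B x T x D` to `B x D x T`, i.e. mel channels come before time. If existing data is stored time-major, swapping the last two dimensions is enough, as in this torch-only sketch (shapes chosen to match the tests, not taken from the codebase):

```python
# Converting a time-major batch (B x T x D) to the channel-major layout
# (B x D x T) that the updated tests feed to the speaker encoders.
import torch

time_major = torch.rand(4, 20, 80)          # B x T x D
channel_major = time_major.transpose(1, 2)  # B x D x T (a non-contiguous view)
print(channel_major.shape)                  # torch.Size([4, 80, 20])
```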
@@ -38,7 +38,7 @@ class SpeakerManagerTest(unittest.TestCase):
         # load a sample audio and compute embedding
         waveform = ap.load_wav(sample_wav_path)
         mel = ap.melspectrogram(waveform)
-        d_vector = manager.compute_d_vector(mel.T)
+        d_vector = manager.compute_d_vector(mel)
         assert d_vector.shape[1] == 256
 
         # compute d_vector directly from an input file