Fix speaker encoder test

This commit is contained in:
Eren Gölge 2021-12-30 12:36:30 +00:00
parent 497332bd46
commit 348b5c96a2
5 changed files with 11 additions and 11 deletions

View File

@ -127,7 +127,7 @@ class LSTMSpeakerEncoder(nn.Module):
@torch.no_grad() @torch.no_grad()
def inference(self, x, l2_norm=True): def inference(self, x, l2_norm=True):
d = self.layers.forward(x, l2_norm=l2_norm) d = self.forward(x, l2_norm=l2_norm)
return d return d
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True): def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):

View File

@ -398,10 +398,10 @@ class Vits(BaseTTS):
self.num_speakers = self.speaker_manager.num_speakers self.num_speakers = self.speaker_manager.num_speakers
if self.args.use_speaker_embedding: if self.args.use_speaker_embedding:
self._init_speaker_embedding(config) self._init_speaker_embedding()
if self.args.use_d_vector_file: if self.args.use_d_vector_file:
self._init_d_vector(config) self._init_d_vector()
# TODO: make this a function # TODO: make this a function
if self.args.use_speaker_encoder_as_loss: if self.args.use_speaker_encoder_as_loss:
@ -436,14 +436,14 @@ class Vits(BaseTTS):
self.audio_transform = None self.audio_transform = None
""" """
def _init_speaker_embedding(self, config): def _init_speaker_embedding(self):
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
if self.num_speakers > 0: if self.num_speakers > 0:
print(" > initialization of speaker-embedding layers.") print(" > initialization of speaker-embedding layers.")
self.embedded_speaker_dim = self.args.speaker_embedding_channels self.embedded_speaker_dim = self.args.speaker_embedding_channels
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
def _init_d_vector(self, config): def _init_d_vector(self):
# pylint: disable=attribute-defined-outside-init # pylint: disable=attribute-defined-outside-init
if hasattr(self, "emb_g"): if hasattr(self, "emb_g"):
raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")

View File

@ -270,7 +270,7 @@ class SpeakerManager:
waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate) waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
if not self.speaker_encoder_config.model_params.get("use_torch_spec", False): if not self.speaker_encoder_config.model_params.get("use_torch_spec", False):
m_input = self.speaker_encoder_ap.melspectrogram(waveform) m_input = self.speaker_encoder_ap.melspectrogram(waveform)
m_input = torch.from_numpy(m_input.T) m_input = torch.from_numpy(m_input)
else: else:
m_input = torch.from_numpy(waveform) m_input = torch.from_numpy(waveform)

View File

@ -13,7 +13,7 @@ file_path = get_tests_input_path()
class LSTMSpeakerEncoderTests(unittest.TestCase): class LSTMSpeakerEncoderTests(unittest.TestCase):
# pylint: disable=R0201 # pylint: disable=R0201
def test_in_out(self): def test_in_out(self):
dummy_input = T.rand(4, 20, 80) # B x T x D dummy_input = T.rand(4, 80, 20) # B x D x T
dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)] dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
model = LSTMSpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3) model = LSTMSpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3)
# computing d vectors # computing d vectors
@ -34,7 +34,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
assert output.type() == "torch.FloatTensor" assert output.type() == "torch.FloatTensor"
assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}" assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
# compute d for a given batch # compute d for a given batch
dummy_input = T.rand(1, 240, 80) # B x T x D dummy_input = T.rand(1, 80, 240) # B x T x D
output = model.compute_embedding(dummy_input, num_frames=160, num_eval=5) output = model.compute_embedding(dummy_input, num_frames=160, num_eval=5)
assert output.shape[0] == 1 assert output.shape[0] == 1
assert output.shape[1] == 256 assert output.shape[1] == 256
@ -44,7 +44,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
class ResNetSpeakerEncoderTests(unittest.TestCase): class ResNetSpeakerEncoderTests(unittest.TestCase):
# pylint: disable=R0201 # pylint: disable=R0201
def test_in_out(self): def test_in_out(self):
dummy_input = T.rand(4, 20, 80) # B x T x D dummy_input = T.rand(4, 80, 20) # B x D x T
dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)] dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
model = ResNetSpeakerEncoder(input_dim=80, proj_dim=256) model = ResNetSpeakerEncoder(input_dim=80, proj_dim=256)
# computing d vectors # computing d vectors
@ -61,7 +61,7 @@ class ResNetSpeakerEncoderTests(unittest.TestCase):
assert output.type() == "torch.FloatTensor" assert output.type() == "torch.FloatTensor"
assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}" assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
# compute d for a given batch # compute d for a given batch
dummy_input = T.rand(1, 240, 80) # B x T x D dummy_input = T.rand(1, 80, 240) # B x D x T
output = model.compute_embedding(dummy_input, num_frames=160, num_eval=10) output = model.compute_embedding(dummy_input, num_frames=160, num_eval=10)
assert output.shape[0] == 1 assert output.shape[0] == 1
assert output.shape[1] == 256 assert output.shape[1] == 256

View File

@ -38,7 +38,7 @@ class SpeakerManagerTest(unittest.TestCase):
# load a sample audio and compute embedding # load a sample audio and compute embedding
waveform = ap.load_wav(sample_wav_path) waveform = ap.load_wav(sample_wav_path)
mel = ap.melspectrogram(waveform) mel = ap.melspectrogram(waveform)
d_vector = manager.compute_d_vector(mel.T) d_vector = manager.compute_d_vector(mel)
assert d_vector.shape[1] == 256 assert d_vector.shape[1] == 256
# compute d_vector directly from an input file # compute d_vector directly from an input file