mirror of https://github.com/coqui-ai/TTS.git
Fix speaker encoder test
This commit is contained in:
parent
497332bd46
commit
348b5c96a2
|
@ -127,7 +127,7 @@ class LSTMSpeakerEncoder(nn.Module):
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference(self, x, l2_norm=True):
|
def inference(self, x, l2_norm=True):
|
||||||
d = self.layers.forward(x, l2_norm=l2_norm)
|
d = self.forward(x, l2_norm=l2_norm)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
|
def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True):
|
||||||
|
|
|
@ -398,10 +398,10 @@ class Vits(BaseTTS):
|
||||||
self.num_speakers = self.speaker_manager.num_speakers
|
self.num_speakers = self.speaker_manager.num_speakers
|
||||||
|
|
||||||
if self.args.use_speaker_embedding:
|
if self.args.use_speaker_embedding:
|
||||||
self._init_speaker_embedding(config)
|
self._init_speaker_embedding()
|
||||||
|
|
||||||
if self.args.use_d_vector_file:
|
if self.args.use_d_vector_file:
|
||||||
self._init_d_vector(config)
|
self._init_d_vector()
|
||||||
|
|
||||||
# TODO: make this a function
|
# TODO: make this a function
|
||||||
if self.args.use_speaker_encoder_as_loss:
|
if self.args.use_speaker_encoder_as_loss:
|
||||||
|
@ -436,14 +436,14 @@ class Vits(BaseTTS):
|
||||||
self.audio_transform = None
|
self.audio_transform = None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _init_speaker_embedding(self, config):
|
def _init_speaker_embedding(self):
|
||||||
# pylint: disable=attribute-defined-outside-init
|
# pylint: disable=attribute-defined-outside-init
|
||||||
if self.num_speakers > 0:
|
if self.num_speakers > 0:
|
||||||
print(" > initialization of speaker-embedding layers.")
|
print(" > initialization of speaker-embedding layers.")
|
||||||
self.embedded_speaker_dim = self.args.speaker_embedding_channels
|
self.embedded_speaker_dim = self.args.speaker_embedding_channels
|
||||||
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
|
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
|
||||||
|
|
||||||
def _init_d_vector(self, config):
|
def _init_d_vector(self):
|
||||||
# pylint: disable=attribute-defined-outside-init
|
# pylint: disable=attribute-defined-outside-init
|
||||||
if hasattr(self, "emb_g"):
|
if hasattr(self, "emb_g"):
|
||||||
raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
|
raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.")
|
||||||
|
|
|
@ -270,7 +270,7 @@ class SpeakerManager:
|
||||||
waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
|
waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
|
||||||
if not self.speaker_encoder_config.model_params.get("use_torch_spec", False):
|
if not self.speaker_encoder_config.model_params.get("use_torch_spec", False):
|
||||||
m_input = self.speaker_encoder_ap.melspectrogram(waveform)
|
m_input = self.speaker_encoder_ap.melspectrogram(waveform)
|
||||||
m_input = torch.from_numpy(m_input.T)
|
m_input = torch.from_numpy(m_input)
|
||||||
else:
|
else:
|
||||||
m_input = torch.from_numpy(waveform)
|
m_input = torch.from_numpy(waveform)
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ file_path = get_tests_input_path()
|
||||||
class LSTMSpeakerEncoderTests(unittest.TestCase):
|
class LSTMSpeakerEncoderTests(unittest.TestCase):
|
||||||
# pylint: disable=R0201
|
# pylint: disable=R0201
|
||||||
def test_in_out(self):
|
def test_in_out(self):
|
||||||
dummy_input = T.rand(4, 20, 80) # B x T x D
|
dummy_input = T.rand(4, 80, 20) # B x D x T
|
||||||
dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
|
dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
|
||||||
model = LSTMSpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3)
|
model = LSTMSpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3)
|
||||||
# computing d vectors
|
# computing d vectors
|
||||||
|
@ -34,7 +34,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
|
||||||
assert output.type() == "torch.FloatTensor"
|
assert output.type() == "torch.FloatTensor"
|
||||||
assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
|
assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
|
||||||
# compute d for a given batch
|
# compute d for a given batch
|
||||||
dummy_input = T.rand(1, 240, 80) # B x T x D
|
dummy_input = T.rand(1, 80, 240) # B x T x D
|
||||||
output = model.compute_embedding(dummy_input, num_frames=160, num_eval=5)
|
output = model.compute_embedding(dummy_input, num_frames=160, num_eval=5)
|
||||||
assert output.shape[0] == 1
|
assert output.shape[0] == 1
|
||||||
assert output.shape[1] == 256
|
assert output.shape[1] == 256
|
||||||
|
@ -44,7 +44,7 @@ class LSTMSpeakerEncoderTests(unittest.TestCase):
|
||||||
class ResNetSpeakerEncoderTests(unittest.TestCase):
|
class ResNetSpeakerEncoderTests(unittest.TestCase):
|
||||||
# pylint: disable=R0201
|
# pylint: disable=R0201
|
||||||
def test_in_out(self):
|
def test_in_out(self):
|
||||||
dummy_input = T.rand(4, 20, 80) # B x T x D
|
dummy_input = T.rand(4, 80, 20) # B x D x T
|
||||||
dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
|
dummy_hidden = [T.rand(2, 4, 128), T.rand(2, 4, 128)]
|
||||||
model = ResNetSpeakerEncoder(input_dim=80, proj_dim=256)
|
model = ResNetSpeakerEncoder(input_dim=80, proj_dim=256)
|
||||||
# computing d vectors
|
# computing d vectors
|
||||||
|
@ -61,7 +61,7 @@ class ResNetSpeakerEncoderTests(unittest.TestCase):
|
||||||
assert output.type() == "torch.FloatTensor"
|
assert output.type() == "torch.FloatTensor"
|
||||||
assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
|
assert abs(assert_diff) < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"
|
||||||
# compute d for a given batch
|
# compute d for a given batch
|
||||||
dummy_input = T.rand(1, 240, 80) # B x T x D
|
dummy_input = T.rand(1, 80, 240) # B x D x T
|
||||||
output = model.compute_embedding(dummy_input, num_frames=160, num_eval=10)
|
output = model.compute_embedding(dummy_input, num_frames=160, num_eval=10)
|
||||||
assert output.shape[0] == 1
|
assert output.shape[0] == 1
|
||||||
assert output.shape[1] == 256
|
assert output.shape[1] == 256
|
||||||
|
|
|
@ -38,7 +38,7 @@ class SpeakerManagerTest(unittest.TestCase):
|
||||||
# load a sample audio and compute embedding
|
# load a sample audio and compute embedding
|
||||||
waveform = ap.load_wav(sample_wav_path)
|
waveform = ap.load_wav(sample_wav_path)
|
||||||
mel = ap.melspectrogram(waveform)
|
mel = ap.melspectrogram(waveform)
|
||||||
d_vector = manager.compute_d_vector(mel.T)
|
d_vector = manager.compute_d_vector(mel)
|
||||||
assert d_vector.shape[1] == 256
|
assert d_vector.shape[1] == 256
|
||||||
|
|
||||||
# compute d_vector directly from an input file
|
# compute d_vector directly from an input file
|
||||||
|
|
Loading…
Reference in New Issue