mirror of https://github.com/coqui-ai/TTS.git
Add unit tests for SoftmaxAngleProtoLoss and ResNetSpeakerEncoder, plus a bug fix
This commit is contained in:
parent
7a9a27282a
commit
bc5307caa0
|
@ -126,7 +126,7 @@ class ResNetSpeakerEncoder(nn.Module):
|
|||
nn.init.xavier_normal_(out)
|
||||
return out
|
||||
|
||||
def forward(self, x, training=True):
|
||||
def forward(self, x, l2_norm=False):
|
||||
x = x.transpose(1, 2)
|
||||
with torch.no_grad():
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
|
@ -157,7 +157,7 @@ class ResNetSpeakerEncoder(nn.Module):
|
|||
x = x.view(x.size()[0], -1)
|
||||
x = self.fc(x)
|
||||
|
||||
if not training:
|
||||
if l2_norm:
|
||||
x = torch.nn.functional.normalize(x, p=2, dim=1)
|
||||
return x
|
||||
|
||||
|
@ -179,7 +179,7 @@ class ResNetSpeakerEncoder(nn.Module):
|
|||
offset = int(offset)
|
||||
end_offset = int(offset+num_frames)
|
||||
frames = x[:, offset:end_offset]
|
||||
embed = self.forward(frames, training=False)
|
||||
embed = self.forward(frames, l2_norm=True)
|
||||
embeddings.append(embed)
|
||||
|
||||
embeddings = torch.stack(embeddings)
|
||||
|
|
|
@ -46,6 +46,7 @@
|
|||
"batch_size": 32,
|
||||
"output_path": "", // DATASET-RELATED: output path for all training outputs.
|
||||
"model_params": {
|
||||
"model_name": "lstm",
|
||||
"input_dim": 40,
|
||||
"proj_dim": 256,
|
||||
"lstm_dim": 768,
|
||||
|
@ -54,8 +55,7 @@
|
|||
},
|
||||
"storage": {
|
||||
"sample_from_storage_p": 0.66, // the probability with which we'll sample from the DataSet in-memory storage
|
||||
"storage_size": 15, // the size of the in-memory storage with respect to a single batch
|
||||
"additive_noise": 1e-5 // add very small gaussian noise to the data in order to increase robustness
|
||||
"storage_size": 15 // the size of the in-memory storage with respect to a single batch
|
||||
},
|
||||
"datasets":null
|
||||
}
|
|
@ -3,13 +3,13 @@ import unittest
|
|||
import torch as T
|
||||
|
||||
from tests import get_tests_input_path
|
||||
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss
|
||||
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
|
||||
from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder
|
||||
# from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
|
||||
from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder
|
||||
file_path = get_tests_input_path()
|
||||
|
||||
|
||||
class SpeakerEncoderTests(unittest.TestCase):
|
||||
class LSTMSpeakerEncoderTests(unittest.TestCase):
|
||||
# pylint: disable=R0201
|
||||
def test_in_out(self):
|
||||
dummy_input = T.rand(4, 20, 80) # B x T x D
|
||||
|
@ -39,6 +39,31 @@ class SpeakerEncoderTests(unittest.TestCase):
|
|||
assert output.shape[1] == 256
|
||||
assert len(output.shape) == 2
|
||||
|
||||
class ResNetSpeakerEncoderTests(unittest.TestCase):
    """Shape and normalization checks for ``ResNetSpeakerEncoder``."""

    # pylint: disable=R0201
    def test_in_out(self):
        """Exercise ``forward`` (with and without ``l2_norm``) and ``compute_embedding``."""
        dummy_input = T.rand(4, 20, 80)  # B x T x D
        model = ResNetSpeakerEncoder(input_dim=80, proj_dim=256)

        # computing d vectors (no normalization requested)
        output = model.forward(dummy_input)
        assert output.shape[0] == 4
        assert output.shape[1] == 256

        # computing d vectors with L2 normalization requested
        output = model.forward(dummy_input, l2_norm=True)
        assert output.shape[0] == 4
        assert output.shape[1] == 256

        # check normalization: re-normalizing an already L2-normalized output
        # must leave it unchanged. Compare with the largest absolute
        # element-wise difference; a signed sum lets positive and negative
        # errors cancel and would hide a wrong result.
        output_norm = T.nn.functional.normalize(output, dim=1, p=2)
        assert_diff = (output_norm - output).abs().max().item()
        assert output.type() == "torch.FloatTensor"
        assert assert_diff < 1e-4, f" [!] output_norm has wrong values - {assert_diff}"

        # compute d for a given batch
        dummy_input = T.rand(1, 240, 80)  # B x T x D
        output = model.compute_embedding(dummy_input, num_frames=160, num_eval=10)
        assert output.shape[0] == 1
        assert output.shape[1] == 256
        assert len(output.shape) == 2
|
||||
|
||||
class GE2ELossTests(unittest.TestCase):
|
||||
# pylint: disable=R0201
|
||||
|
@ -67,7 +92,6 @@ class GE2ELossTests(unittest.TestCase):
|
|||
output = loss.forward(dummy_input)
|
||||
assert output.item() < 0.005
|
||||
|
||||
|
||||
class AngleProtoLossTests(unittest.TestCase):
|
||||
# pylint: disable=R0201
|
||||
def test_in_out(self):
|
||||
|
@ -96,3 +120,24 @@ class AngleProtoLossTests(unittest.TestCase):
|
|||
loss = AngleProtoLoss()
|
||||
output = loss.forward(dummy_input)
|
||||
assert output.item() < 0.005
|
||||
|
||||
class SoftmaxAngleProtoLossTests(unittest.TestCase):
    """Sanity checks for ``SoftmaxAngleProtoLoss``: the loss must be non-negative."""

    # pylint: disable=R0201
    def test_in_out(self):
        """Evaluate the loss on a random input and on a constant all-ones input."""
        embedding_dim = 64
        num_speakers = 5
        batch_size = 4

        # integer labels drawn from [0, num_speakers), one per (row, column) of the input
        dummy_label = T.randint(low=0, high=num_speakers, size=(batch_size, num_speakers))

        # Both inputs are built with shape (batch_size, num_speakers, embedding_dim).
        # NOTE(review): which leading axis the loss treats as speakers vs.
        # utterances is not visible from this test - confirm against the loss.
        cases = {
            "random": T.rand(batch_size, num_speakers, embedding_dim),
            "all ones": T.ones(batch_size, num_speakers, embedding_dim),
        }

        for case_name, dummy_input in cases.items():
            # subTest makes a failure report which input triggered it
            with self.subTest(input=case_name):
                # a fresh loss instance per case, as in the original test
                loss = SoftmaxAngleProtoLoss(embedding_dim=embedding_dim, n_speakers=num_speakers)
                output = loss.forward(dummy_input, dummy_label)
                assert output.item() >= 0.0
|
||||
|
|
|
@ -19,7 +19,7 @@ config = SpeakerEncoderConfig(
|
|||
print_step=1,
|
||||
save_step=1,
|
||||
print_eval=True,
|
||||
audio=BaseAudioConfig(num_mels=40),
|
||||
audio=BaseAudioConfig(num_mels=80),
|
||||
)
|
||||
config.audio.do_trim_silence = True
|
||||
config.audio.trim_db = 60
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
|||
|
||||
from tests import get_tests_input_path
|
||||
from TTS.config import load_config
|
||||
from TTS.speaker_encoder.model import SpeakerEncoder
|
||||
from TTS.speaker_encoder.utils.generic_utils import setup_model
|
||||
from TTS.speaker_encoder.utils.io import save_checkpoint
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
@ -28,7 +28,7 @@ class SpeakerManagerTest(unittest.TestCase):
|
|||
config.audio.resample = True
|
||||
|
||||
# create a dummy speaker encoder
|
||||
model = SpeakerEncoder(**config.model_params)
|
||||
model = setup_model(config)
|
||||
save_checkpoint(model, None, None, get_tests_input_path(), 0)
|
||||
|
||||
# load audio processor and speaker encoder
|
||||
|
|
Loading…
Reference in New Issue