diff --git a/tests/test_encoder.py b/tests/test_encoder.py index 711ad195..46266f29 100644 --- a/tests/test_encoder.py +++ b/tests/test_encoder.py @@ -4,7 +4,7 @@ import unittest import torch as T from tests import get_tests_input_path -from mozilla_voice_tts.speaker_encoder.loss import GE2ELoss +from mozilla_voice_tts.speaker_encoder.losses import GE2ELoss, AngleProtoLoss from mozilla_voice_tts.speaker_encoder.model import SpeakerEncoder from mozilla_voice_tts.utils.io import load_config @@ -59,6 +59,7 @@ class GE2ELossTests(unittest.TestCase): dummy_input = T.ones(4, 5, 64) # num_speaker x num_utterance x dim loss = GE2ELoss(loss_method="softmax") output = loss.forward(dummy_input) + assert output.item() >= 0.0 # check speaker loss with orthogonal d-vectors dummy_input = T.empty(3, 64) dummy_input = T.nn.init.orthogonal(dummy_input) @@ -73,6 +74,34 @@ class GE2ELossTests(unittest.TestCase): output = loss.forward(dummy_input) assert output.item() < 0.005 +class AngleProtoLossTests(unittest.TestCase): + # pylint: disable=R0201 + def test_in_out(self): + # check random input + dummy_input = T.rand(4, 5, 64) # num_speaker x num_utterance x dim + loss = AngleProtoLoss() + output = loss.forward(dummy_input) + assert output.item() >= 0.0 + + # check all zeros + dummy_input = T.ones(4, 5, 64) # num_speaker x num_utterance x dim + loss = AngleProtoLoss() + output = loss.forward(dummy_input) + assert output.item() >= 0.0 + + # check speaker loss with orthogonal d-vectors + dummy_input = T.empty(3, 64) + dummy_input = T.nn.init.orthogonal(dummy_input) + dummy_input = T.cat( + [ + dummy_input[0].repeat(5, 1, 1).transpose(0, 1), + dummy_input[1].repeat(5, 1, 1).transpose(0, 1), + dummy_input[2].repeat(5, 1, 1).transpose(0, 1), + ] + ) # num_speaker x num_utterance x dim + loss = AngleProtoLoss() + output = loss.forward(dummy_input) + assert output.item() < 0.005 # class LoaderTest(unittest.TestCase): # def test_output(self):