mirror of https://github.com/coqui-ai/TTS.git
add test for AngleProtoLoss
This commit is contained in:
parent
f3b8ef4272
commit
3c6c749de2
|
@ -4,7 +4,7 @@ import unittest
|
||||||
import torch as T
|
import torch as T
|
||||||
from tests import get_tests_input_path
|
from tests import get_tests_input_path
|
||||||
|
|
||||||
from mozilla_voice_tts.speaker_encoder.loss import GE2ELoss
|
from mozilla_voice_tts.speaker_encoder.losses import GE2ELoss, AngleProtoLoss
|
||||||
from mozilla_voice_tts.speaker_encoder.model import SpeakerEncoder
|
from mozilla_voice_tts.speaker_encoder.model import SpeakerEncoder
|
||||||
from mozilla_voice_tts.utils.io import load_config
|
from mozilla_voice_tts.utils.io import load_config
|
||||||
|
|
||||||
|
@ -59,6 +59,7 @@ class GE2ELossTests(unittest.TestCase):
|
||||||
dummy_input = T.ones(4, 5, 64) # num_speaker x num_utterance x dim
|
dummy_input = T.ones(4, 5, 64) # num_speaker x num_utterance x dim
|
||||||
loss = GE2ELoss(loss_method="softmax")
|
loss = GE2ELoss(loss_method="softmax")
|
||||||
output = loss.forward(dummy_input)
|
output = loss.forward(dummy_input)
|
||||||
|
assert output.item() >= 0.0
|
||||||
# check speaker loss with orthogonal d-vectors
|
# check speaker loss with orthogonal d-vectors
|
||||||
dummy_input = T.empty(3, 64)
|
dummy_input = T.empty(3, 64)
|
||||||
dummy_input = T.nn.init.orthogonal(dummy_input)
|
dummy_input = T.nn.init.orthogonal(dummy_input)
|
||||||
|
@ -73,6 +74,34 @@ class GE2ELossTests(unittest.TestCase):
|
||||||
output = loss.forward(dummy_input)
|
output = loss.forward(dummy_input)
|
||||||
assert output.item() < 0.005
|
assert output.item() < 0.005
|
||||||
|
|
||||||
|
class AngleProtoLossTests(unittest.TestCase):
|
||||||
|
# pylint: disable=R0201
|
||||||
|
def test_in_out(self):
|
||||||
|
# check random input
|
||||||
|
dummy_input = T.rand(4, 5, 64) # num_speaker x num_utterance x dim
|
||||||
|
loss = AngleProtoLoss()
|
||||||
|
output = loss.forward(dummy_input)
|
||||||
|
assert output.item() >= 0.0
|
||||||
|
|
||||||
|
# check all zeros
|
||||||
|
dummy_input = T.ones(4, 5, 64) # num_speaker x num_utterance x dim
|
||||||
|
loss = AngleProtoLoss()
|
||||||
|
output = loss.forward(dummy_input)
|
||||||
|
assert output.item() >= 0.0
|
||||||
|
|
||||||
|
# check speaker loss with orthogonal d-vectors
|
||||||
|
dummy_input = T.empty(3, 64)
|
||||||
|
dummy_input = T.nn.init.orthogonal(dummy_input)
|
||||||
|
dummy_input = T.cat(
|
||||||
|
[
|
||||||
|
dummy_input[0].repeat(5, 1, 1).transpose(0, 1),
|
||||||
|
dummy_input[1].repeat(5, 1, 1).transpose(0, 1),
|
||||||
|
dummy_input[2].repeat(5, 1, 1).transpose(0, 1),
|
||||||
|
]
|
||||||
|
) # num_speaker x num_utterance x dim
|
||||||
|
loss = AngleProtoLoss()
|
||||||
|
output = loss.forward(dummy_input)
|
||||||
|
assert output.item() < 0.005
|
||||||
|
|
||||||
# class LoaderTest(unittest.TestCase):
|
# class LoaderTest(unittest.TestCase):
|
||||||
# def test_output(self):
|
# def test_output(self):
|
||||||
|
|
Loading…
Reference in New Issue