mirror of https://github.com/coqui-ai/TTS.git
add test for AngleProtoLoss
This commit is contained in:
parent
f3b8ef4272
commit
3c6c749de2
|
@ -4,7 +4,7 @@ import unittest
|
|||
import torch as T
|
||||
from tests import get_tests_input_path
|
||||
|
||||
from mozilla_voice_tts.speaker_encoder.loss import GE2ELoss
|
||||
from mozilla_voice_tts.speaker_encoder.losses import GE2ELoss, AngleProtoLoss
|
||||
from mozilla_voice_tts.speaker_encoder.model import SpeakerEncoder
|
||||
from mozilla_voice_tts.utils.io import load_config
|
||||
|
||||
|
@ -59,6 +59,7 @@ class GE2ELossTests(unittest.TestCase):
|
|||
dummy_input = T.ones(4, 5, 64) # num_speaker x num_utterance x dim
|
||||
loss = GE2ELoss(loss_method="softmax")
|
||||
output = loss.forward(dummy_input)
|
||||
assert output.item() >= 0.0
|
||||
# check speaker loss with orthogonal d-vectors
|
||||
dummy_input = T.empty(3, 64)
|
||||
dummy_input = T.nn.init.orthogonal(dummy_input)
|
||||
|
@ -73,6 +74,34 @@ class GE2ELossTests(unittest.TestCase):
|
|||
output = loss.forward(dummy_input)
|
||||
assert output.item() < 0.005
|
||||
|
||||
class AngleProtoLossTests(unittest.TestCase):
|
||||
# pylint: disable=R0201
|
||||
def test_in_out(self):
|
||||
# check random input
|
||||
dummy_input = T.rand(4, 5, 64) # num_speaker x num_utterance x dim
|
||||
loss = AngleProtoLoss()
|
||||
output = loss.forward(dummy_input)
|
||||
assert output.item() >= 0.0
|
||||
|
||||
# check all zeros
|
||||
dummy_input = T.ones(4, 5, 64) # num_speaker x num_utterance x dim
|
||||
loss = AngleProtoLoss()
|
||||
output = loss.forward(dummy_input)
|
||||
assert output.item() >= 0.0
|
||||
|
||||
# check speaker loss with orthogonal d-vectors
|
||||
dummy_input = T.empty(3, 64)
|
||||
dummy_input = T.nn.init.orthogonal(dummy_input)
|
||||
dummy_input = T.cat(
|
||||
[
|
||||
dummy_input[0].repeat(5, 1, 1).transpose(0, 1),
|
||||
dummy_input[1].repeat(5, 1, 1).transpose(0, 1),
|
||||
dummy_input[2].repeat(5, 1, 1).transpose(0, 1),
|
||||
]
|
||||
) # num_speaker x num_utterance x dim
|
||||
loss = AngleProtoLoss()
|
||||
output = loss.forward(dummy_input)
|
||||
assert output.item() < 0.005
|
||||
|
||||
# class LoaderTest(unittest.TestCase):
|
||||
# def test_output(self):
|
||||
|
|
Loading…
Reference in New Issue