diff --git a/tests/test_vocoder_losses.py b/tests/test_vocoder_losses.py
index 965e68ad..d578a130 100644
--- a/tests/test_vocoder_losses.py
+++ b/tests/test_vocoder_losses.py
@@ -5,7 +5,7 @@
 from tests import get_tests_input_path, get_tests_output_path, get_tests_path
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.io import load_config
-from TTS.vocoder.layers.losses import MultiScaleSTFTLoss, STFTLoss, TorchSTFT
+from TTS.vocoder.layers.losses import MultiScaleSTFTLoss, STFTLoss, TorchSTFT, MelganFeatureLoss
 
 TESTS_PATH = get_tests_path()
 
@@ -52,3 +52,35 @@ def test_multiscale_stft_loss():
     loss_m, loss_sc = stft_loss(wav, torch.rand_like(wav))
     assert loss_sc < 1.0
     assert loss_m + loss_sc > 0
+
+
+def test_melgan_feature_loss():
+    """Exercise MelganFeatureLoss on nested lists of feature tensors.
+
+    Independently drawn random features must give a loss of at most 1.0;
+    features that share the very same tensors must give a loss of exactly 0.
+    """
+    num_scales = 5
+    num_layers = 4
+    feat_shape = [3, 5, 7]
+
+    def random_feats():
+        # one inner list per scale, one random tensor per layer
+        return [
+            [torch.rand(feat_shape) for _ in range(num_layers)]
+            for _ in range(num_scales)
+        ]
+
+    # if all the features are different.
+    feats_real = random_feats()
+    feats_fake = random_feats()
+    loss_func = MelganFeatureLoss()
+    loss = loss_func(feats_fake, feats_real)
+    assert loss.item() <= 1.0
+
+    # if all the features are the same
+    feats_real = random_feats()
+    # fresh list objects, but each entry is the same tensor as in feats_real
+    feats_fake = [list(scale_feats) for scale_feats in feats_real]
+    loss_func = MelganFeatureLoss()
+    loss = loss_func(feats_fake, feats_real)
+    assert loss.item() == 0