mirror of https://github.com/coqui-ai/TTS.git
test melgan feature loss
This commit is contained in:
parent
c8953f4da9
commit
a669a492c6
|
@ -5,7 +5,7 @@ from tests import get_tests_input_path, get_tests_output_path, get_tests_path
|
||||||
|
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.io import load_config
|
from TTS.utils.io import load_config
|
||||||
from TTS.vocoder.layers.losses import MultiScaleSTFTLoss, STFTLoss, TorchSTFT
|
from TTS.vocoder.layers.losses import MultiScaleSTFTLoss, STFTLoss, TorchSTFT, MelganFeatureLoss
|
||||||
|
|
||||||
TESTS_PATH = get_tests_path()
|
TESTS_PATH = get_tests_path()
|
||||||
|
|
||||||
|
@ -52,3 +52,41 @@ def test_multiscale_stft_loss():
|
||||||
loss_m, loss_sc = stft_loss(wav, torch.rand_like(wav))
|
loss_m, loss_sc = stft_loss(wav, torch.rand_like(wav))
|
||||||
assert loss_sc < 1.0
|
assert loss_sc < 1.0
|
||||||
assert loss_m + loss_sc > 0
|
assert loss_m + loss_sc > 0
|
||||||
|
|
||||||
|
def test_melgan_feature_loss():
|
||||||
|
feats_real = []
|
||||||
|
feats_fake = []
|
||||||
|
|
||||||
|
# if all the features are different.
|
||||||
|
for _ in range(5): # different scales
|
||||||
|
scale_feats_real = []
|
||||||
|
scale_feats_fake = []
|
||||||
|
for _ in range(4): # different layers
|
||||||
|
scale_feats_real.append(torch.rand([3, 5, 7]))
|
||||||
|
scale_feats_fake.append(torch.rand([3, 5, 7]))
|
||||||
|
feats_real.append(scale_feats_real)
|
||||||
|
feats_fake.append(scale_feats_fake)
|
||||||
|
|
||||||
|
loss_func = MelganFeatureLoss()
|
||||||
|
loss = loss_func(feats_fake, feats_real)
|
||||||
|
assert loss.item() <= 1.0
|
||||||
|
|
||||||
|
|
||||||
|
feats_real = []
|
||||||
|
feats_fake = []
|
||||||
|
|
||||||
|
# if all the features are the same
|
||||||
|
for _ in range(5): # different scales
|
||||||
|
scale_feats_real = []
|
||||||
|
scale_feats_fake = []
|
||||||
|
for _ in range(4): # different layers
|
||||||
|
tensor = torch.rand([3, 5, 7])
|
||||||
|
scale_feats_real.append(tensor)
|
||||||
|
scale_feats_fake.append(tensor)
|
||||||
|
feats_real.append(scale_feats_real)
|
||||||
|
feats_fake.append(scale_feats_fake)
|
||||||
|
|
||||||
|
loss_func = MelganFeatureLoss()
|
||||||
|
loss = loss_func(feats_fake, feats_real)
|
||||||
|
assert loss.item() == 0
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue