From a3213762ae9deabcc6dc6e0ca6af4acfcd89718b Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 26 Oct 2020 17:23:28 +0100 Subject: [PATCH] update wavegrad tests --- tests/test_wavegrad_layers.py | 80 +++++++++++++++++++++++++++++++++++ tests/test_wavegrad_train.py | 57 +++++++++++++++++++++++++ 2 files changed, 137 insertions(+) create mode 100644 tests/test_wavegrad_layers.py create mode 100644 tests/test_wavegrad_train.py diff --git a/tests/test_wavegrad_layers.py b/tests/test_wavegrad_layers.py new file mode 100644 index 00000000..a1c6a7e5 --- /dev/null +++ b/tests/test_wavegrad_layers.py @@ -0,0 +1,80 @@ +import torch + +from TTS.vocoder.layers.wavegrad import PositionalEncoding, FiLM, UBlock, DBlock +from TTS.vocoder.models.wavegrad import Wavegrad + + +def test_positional_encoding(): + layer = PositionalEncoding(50) + inp = torch.rand(32, 50, 100) + nl = torch.rand(32) + o = layer(inp, nl) + + assert o.shape[0] == 32 + assert o.shape[1] == 50 + assert o.shape[2] == 100 + assert isinstance(o, torch.FloatTensor) + + +def test_film(): + layer = FiLM(50, 76) + inp = torch.rand(32, 50, 100) + nl = torch.rand(32) + shift, scale = layer(inp, nl) + + assert shift.shape[0] == 32 + assert shift.shape[1] == 76 + assert shift.shape[2] == 100 + assert isinstance(shift, torch.FloatTensor) + + assert scale.shape[0] == 32 + assert scale.shape[1] == 76 + assert scale.shape[2] == 100 + assert isinstance(scale, torch.FloatTensor) + + +def test_ublock(): + inp1 = torch.rand(32, 50, 100) + inp2 = torch.rand(32, 50, 50) + nl = torch.rand(32) + + layer_film = FiLM(50, 100) + layer = UBlock(50, 100, 2, [1, 2, 4, 8]) + + scale, shift = layer_film(inp1, nl) + o = layer(inp2, shift, scale) + + assert o.shape[0] == 32 + assert o.shape[1] == 100 + assert o.shape[2] == 100 + assert isinstance(o, torch.FloatTensor) + + +def test_dblock(): + inp = torch.rand(32, 50, 130) + layer = DBlock(50, 100, 2) + o = layer(inp) + + assert o.shape[0] == 32 + assert o.shape[1] == 100 + assert o.shape[2] == 65 + assert isinstance(o, torch.FloatTensor) + + +def test_wavegrad_forward(): + x = torch.rand(32, 1, 20 * 300) + c = torch.rand(32, 80, 20) + noise_scale = torch.rand(32) + + model = Wavegrad(in_channels=80, + out_channels=1, + upsample_factors=[5, 5, 3, 2, 2], + upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], + [1, 2, 4, 8], [1, 2, 4, 8], + [1, 2, 4, 8]]) + o = model.forward(x, c, noise_scale) + + assert o.shape[0] == 32 + assert o.shape[1] == 1 + assert o.shape[2] == 20 * 300 + assert isinstance(o, torch.FloatTensor) diff --git a/tests/test_wavegrad_train.py b/tests/test_wavegrad_train.py new file mode 100644 index 00000000..1fd1d10e --- /dev/null +++ b/tests/test_wavegrad_train.py @@ -0,0 +1,57 @@ +import copy +import os +import unittest + +import torch +from tests import get_tests_input_path +from torch import nn, optim + +from TTS.vocoder.models.wavegrad import Wavegrad +from TTS.utils.io import load_config +from TTS.utils.audio import AudioProcessor + +#pylint: disable=unused-variable + +torch.manual_seed(1) +use_cuda = torch.cuda.is_available() +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +class WavegradTrainTest(unittest.TestCase): + def test_train_step(self): # pylint: disable=no-self-use + """Test if all layers are updated in a basic training cycle""" + input_dummy = torch.rand(8, 1, 20 * 300).to(device) + mel_spec = torch.rand(8, 80, 20).to(device) + + criterion = torch.nn.L1Loss().to(device) + model = Wavegrad(in_channels=80, + out_channels=1, + upsample_factors=[5, 5, 3, 2, 2], + upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], + [1, 2, 4, 8], [1, 2, 4, 8], + [1, 2, 4, 8]]) + model.train() + model.to(device) + model_ref = copy.deepcopy(model) + count = 0 + for param, param_ref in zip(model.parameters(), + model_ref.parameters()): + assert (param - param_ref).sum() == 0, param + count += 1 + optimizer = optim.Adam(model.parameters(), lr=0.001) + for i in range(5): + y_hat = model.forward(input_dummy, mel_spec, torch.rand(8).to(device)) + optimizer.zero_grad() + loss = criterion(y_hat, input_dummy) + loss.backward() + optimizer.step() + # check parameter changes + count = 0 + for param, param_ref in zip(model.parameters(), + model_ref.parameters()): + # ignore pre-higway layer since it works conditional + # if count not in [145, 59]: + assert (param != param_ref).any( + ), "param {} with shape {} not updated!! \n{}\n{}".format( + count, param.shape, param, param_ref) + count += 1 \ No newline at end of file