From ad3863df3cb84b201befd7bc09a1dcf931482a61 Mon Sep 17 00:00:00 2001 From: erogol Date: Wed, 30 Dec 2020 14:24:04 +0100 Subject: [PATCH] update SS tests for multi-speaker case --- tests/test_speedy_speech_layers.py | 74 ++++++++++++++++++++++-------- 1 file changed, 55 insertions(+), 19 deletions(-) diff --git a/tests/test_speedy_speech_layers.py b/tests/test_speedy_speech_layers.py index b9787009..33a5e615 100644 --- a/tests/test_speedy_speech_layers.py +++ b/tests/test_speedy_speech_layers.py @@ -15,31 +15,31 @@ def test_encoder(): input_dummy = torch.rand(8, 14, 37).to(device) input_lengths = torch.randint(31, 37, (8, )).long().to(device) input_lengths[-1] = 37 - input_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), - 1).to(device) + input_mask = torch.unsqueeze( + sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device) # residual bn conv encoder layer = Encoder(out_channels=11, - hidden_channels=14, - encoder_type='residual_conv_bn').to(device) + in_hidden_channels=14, + encoder_type='residual_conv_bn').to(device) output = layer(input_dummy, input_mask) - assert list(output.shape)==[8, 11, 37] + assert list(output.shape) == [8, 11, 37] # transformer encoder layer = Encoder(out_channels=11, - hidden_channels=14, + in_hidden_channels=14, encoder_type='transformer', encoder_params={ - 'hidden_channels_ffn': 768, - 'num_heads': 2, - "kernel_size": 3, - "dropout_p": 0.1, - "num_layers": 6, - "rel_attn_window_size": 4, - "input_length": None - }).to(device) + 'hidden_channels_ffn': 768, + 'num_heads': 2, + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "rel_attn_window_size": 4, + "input_length": None + }).to(device) output = layer(input_dummy, input_mask) - assert list(output.shape)==[8, 11, 37] + assert list(output.shape) == [8, 11, 37] def test_decoder(): @@ -47,9 +47,10 @@ def test_decoder(): input_lengths = torch.randint(31, 37, (8, )).long().to(device) input_lengths[-1] = 37 - input_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device) + input_mask = torch.unsqueeze( + sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device) - layer = Decoder(out_channels=11, hidden_channels=128).to(device) + layer = Decoder(out_channels=11, in_hidden_channels=128).to(device) output = layer(input_dummy, input_mask) assert list(output.shape) == [8, 11, 37] @@ -59,12 +60,13 @@ def test_duration_predictor(): input_lengths = torch.randint(20, 27, (8, )).long().to(device) input_lengths[-1] = 27 - x_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device) + x_mask = torch.unsqueeze(sequence_mask(input_lengths, input_dummy.size(2)), + 1).to(device) layer = DurationPredictor(hidden_channels=128).to(device) output = layer(input_dummy, x_mask) - assert list(output.shape)==[8, 1, 27] + assert list(output.shape) == [8, 1, 27] def test_speedy_speech(): @@ -93,6 +95,40 @@ def test_speedy_speech(): # forward pass o_de, o_dr, attn = model(x_dummy, x_lengths, y_lengths, durations) + assert list(o_de.shape) == [B, 80, T_de], f"{list(o_de.shape)}" + assert list(attn.shape) == [B, T_de, T_en] + assert list(o_dr.shape) == [B, T_en] + + # with speaker embedding + model = SpeedySpeech(num_chars, + out_channels=80, + hidden_channels=128, + num_speakers=10, + c_in_channels=256).to(device) + model.forward(x_dummy, + x_lengths, + y_lengths, + durations, + g=torch.randint(0, 10, (B,)).to(device)) + + assert list(o_de.shape) == [B, 80, T_de], f"{list(o_de.shape)}" + assert list(attn.shape) == [B, T_de, T_en] + assert list(o_dr.shape) == [B, T_en] + + + # with speaker external embedding + model = SpeedySpeech(num_chars, + out_channels=80, + hidden_channels=128, + num_speakers=10, + external_c=True, + c_in_channels=256).to(device) + model.forward(x_dummy, + x_lengths, + y_lengths, + durations, + g=torch.rand((B,256)).to(device)) + assert list(o_de.shape) == [B, 80, T_de], f"{list(o_de.shape)}" assert list(attn.shape) == [B, T_de, T_en] assert list(o_dr.shape) == [B, T_en] \ No newline at end of file