diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index 7d474c20..ccc3be1c 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -134,8 +134,8 @@ class TestVits(unittest.TestCase): ref_inp = torch.randn(1, 513, spec_len) ref_inp_len = torch.randint(1, spec_effective_len, (1,)) - ref_spk_id = torch.randint(1, num_speakers, (1,)) - tgt_spk_id = torch.randint(1, num_speakers, (1,)) + ref_spk_id = torch.randint(1, num_speakers, (1,)).item() + tgt_spk_id = torch.randint(1, num_speakers, (1,)).item() o_hat, y_mask, (z, z_p, z_hat) = model.voice_conversion(ref_inp, ref_inp_len, ref_spk_id, tgt_spk_id) self.assertEqual(o_hat.shape, (1, 1, spec_len * 256))