diff --git a/tests/tts_tests/test_fast_pitch_e2e.py b/tests/tts_tests/test_fast_pitch_e2e.py index f9526156..ca34dec9 100644 --- a/tests/tts_tests/test_fast_pitch_e2e.py +++ b/tests/tts_tests/test_fast_pitch_e2e.py @@ -6,8 +6,8 @@ import torch from trainer.logging.tensorboard_logger import TensorboardLogger from tests import assertHasAttr, assertHasNotAttr, get_tests_data_path, get_tests_input_path, get_tests_output_path -from TTS.tts.configs.fast_pitch_e2e_config import FastPitchE2EConfig -from TTS.tts.models.forward_tts_e2e import ForwardTTSE2E, ForwardTTSE2EArgs +from TTS.tts.configs.fast_pitch_e2e_config import FastPitchE2eConfig +from TTS.tts.models.forward_tts_e2e import ForwardTTSE2e, ForwardTTSE2eArgs LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json") SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json") @@ -18,17 +18,20 @@ torch.manual_seed(1) use_cuda = torch.cuda.is_available() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +MAX_INPUT_LEN = 57 +MAX_SPEC_LEN = 33 + # pylint: disable=no-self-use class TestFastPitchE2E(unittest.TestCase): def _create_inputs(self, config, batch_size=2): - input_dummy = torch.randint(0, 24, (batch_size, 128)).long().to(device) - input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device) - input_lengths[-1] = 128 - spec = torch.rand(batch_size, 30, config.audio["num_mels"]).to(device) - # spec = torch.rand(batch_size, config.audio["num_mels"], 30).to(device) - spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device) - spec_lengths[-1] = spec.size(1) + + input_dummy = torch.randint(0, 24, (batch_size, MAX_INPUT_LEN)).long().to(device) + input_lengths = torch.randint(10, MAX_INPUT_LEN, (batch_size,)).long().to(device) + input_lengths[-1] = MAX_INPUT_LEN + spec = torch.rand(batch_size, MAX_SPEC_LEN, config.audio["num_mels"]).to(device) + spec_lengths = torch.randint(20, MAX_SPEC_LEN, (batch_size,)).long().to(device) 
+ spec_lengths[-1] = MAX_SPEC_LEN waveform = torch.rand(batch_size, 1, spec.size(1) * config.audio["hop_length"]).to(device) pitch = torch.rand(batch_size, 1, spec.size(1)).to(device) return input_dummy, input_lengths, spec, spec_lengths, waveform, pitch @@ -37,7 +40,7 @@ class TestFastPitchE2E(unittest.TestCase): self.assertEqual( output_dict["model_outputs"].shape[2], config.model_args.spec_segment_size * config.audio["hop_length"] ) - self.assertEqual(output_dict["alignments"].shape, (batch_size, 30, 128)) + self.assertEqual(output_dict["alignments"].shape, (batch_size, MAX_SPEC_LEN, MAX_INPUT_LEN)) self.assertEqual(output_dict["alignments"].max(), 1) self.assertEqual(output_dict["alignments"].min(), 0) self.assertEqual( @@ -45,7 +48,9 @@ class TestFastPitchE2E(unittest.TestCase): ) def _check_inference_outputs(self, outputs, input_dummy, batch_size=1): - feat_len = outputs["encoder_outputs"].shape[1] + feat_dim = 256 # hard-coded based on model architecture + feat_len = outputs["o_en_ex"].shape[2] + self.assertEqual(outputs["o_en_ex"].shape, (batch_size, feat_dim, feat_len)) self.assertEqual(outputs["model_outputs"].shape[:2], (batch_size, 1)) # we don't know the channel dimension self.assertEqual(outputs["alignments"].shape, (batch_size, input_dummy.shape[1], feat_len)) @@ -68,194 +73,188 @@ class TestFastPitchE2E(unittest.TestCase): batch["text_lengths"] = input_lengths batch["mel_lengths"] = spec_lengths batch["mel_input"] = spec - batch["waveform"] = waveform.transpose(1, 2) # B x C X T -> B x T x C + batch["waveform"] = waveform # B x C X T batch["d_vectors"] = None batch["speaker_ids"] = None batch["language_ids"] = None batch["pitch"] = pitch return batch - # def test_init_multispeaker(self): + def test_init_multispeaker(self): - # num_speakers = 10 - # model_args = ForwardTTSE2EArgs() - # model_args.num_speakers = num_speakers - # model_args.use_speaker_embedding = True - # model = ForwardTTSE2E(model_args) - # assertHasAttr(self, model.encoder_model, 
"emb_g") + num_speakers = 10 + model_args = ForwardTTSE2eArgs() + model_args.num_speakers = num_speakers + model_args.use_speaker_embedding = True + model = ForwardTTSE2e(model_args) + assertHasAttr(self, model.encoder_model, "emb_g") - # model_args = ForwardTTSE2EArgs() - # model_args.num_speakers = 0 - # model_args.use_speaker_embedding = True - # model = ForwardTTSE2E(model_args) - # assertHasNotAttr(self, model.encoder_model, "emb_g") + model_args = ForwardTTSE2eArgs() + model_args.num_speakers = 0 + model_args.use_speaker_embedding = True + model = ForwardTTSE2e(model_args) + assertHasNotAttr(self, model.encoder_model, "emb_g") - # model_args = ForwardTTSE2EArgs() - # model_args.num_speakers = 10 - # model_args.use_speaker_embedding = False - # model = ForwardTTSE2E(model_args) - # assertHasNotAttr(self, model.encoder_model, "emb_g") + model_args = ForwardTTSE2eArgs() + model_args.num_speakers = 10 + model_args.use_speaker_embedding = False + model = ForwardTTSE2e(model_args) + assertHasNotAttr(self, model.encoder_model, "emb_g") - # model_args = ForwardTTSE2EArgs(d_vector_dim=101, use_d_vector_file=True) - # model = ForwardTTSE2E(model_args) - # self.assertEqual(model.encoder_model.embedded_speaker_dim, 101) + model_args = ForwardTTSE2eArgs(d_vector_dim=101, use_d_vector_file=True) + model = ForwardTTSE2e(model_args) + self.assertEqual(model.encoder_model.embedded_speaker_dim, 101) - # def test_init_multilingual(self): + def test_init_multilingual(self): + """TODO""" + + def test_get_aux_input(self): + aux_input = {"speaker_ids": None, "style_wav": None, "d_vectors": None, "language_ids": None} + model_args = ForwardTTSE2eArgs() + model = ForwardTTSE2e(model_args) + aux_out = model.get_aux_input(aux_input) + + speaker_id = torch.randint(10, (1,)) + language_id = torch.randint(10, (1,)) + d_vector = torch.rand(1, 128) + aux_input = {"speaker_ids": speaker_id, "style_wav": None, "d_vectors": d_vector, "language_ids": language_id} + aux_out = 
model.get_aux_input(aux_input) + self.assertEqual(aux_out["speaker_ids"].shape, speaker_id.shape) + self.assertEqual(aux_out["language_ids"].shape, language_id.shape) + self.assertEqual(aux_out["d_vectors"].shape, d_vector.unsqueeze(0).transpose(2, 1).shape) + + def test_forward(self): + model_args = ForwardTTSE2eArgs(spec_segment_size=10) + config = FastPitchE2eConfig(model_args=model_args) + input_dummy, input_lengths, spec, spec_lengths, waveform, pitch = self._create_inputs(config) + model = ForwardTTSE2e(config).to(device) + output_dict = model.forward( + x=input_dummy, x_lengths=input_lengths, spec=spec, spec_lengths=spec_lengths, waveform=waveform, pitch=pitch + ) + self._check_forward_outputs(config, output_dict) + + def test_multispeaker_forward(self): + batch_size = 2 + num_speakers = 10 + model_args = ForwardTTSE2eArgs(spec_segment_size=10, num_speakers=num_speakers, use_speaker_embedding=True) + config = FastPitchE2eConfig(model_args=model_args) + config.model_args.spec_segment_size = 10 + + input_dummy, input_lengths, spec, spec_lengths, waveform, pitch = self._create_inputs( + config, batch_size=batch_size + ) + speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + + model = ForwardTTSE2e(config).to(device) + output_dict = model.forward( + x=input_dummy, + x_lengths=input_lengths, + spec=spec, + spec_lengths=spec_lengths, + waveform=waveform, + pitch=pitch, + aux_input={"speaker_ids": speaker_ids}, + ) + self._check_forward_outputs(config, output_dict) + + def test_d_vector_forward(self): + batch_size = 2 + model_args = ForwardTTSE2eArgs(spec_segment_size=10, use_d_vector_file=True, d_vector_dim=256) + config = FastPitchE2eConfig(model_args=model_args) + config.model_args.spec_segment_size = 10 + model = ForwardTTSE2e(config).to(device) + model.train() + input_dummy, input_lengths, spec, spec_lengths, waveform, pitch = self._create_inputs( + config, batch_size=batch_size + ) + d_vectors = torch.randn(batch_size, 
256).to(device) + output_dict = model.forward( + x=input_dummy, + x_lengths=input_lengths, + spec=spec, + spec_lengths=spec_lengths, + waveform=waveform, + pitch=pitch, + aux_input={"d_vectors": d_vectors}, + ) + self._check_forward_outputs(config, output_dict) + + # def test_multilingual_forward(self): # """TODO""" - # def test_get_aux_input(self): - # aux_input = {"speaker_ids": None, "style_wav": None, "d_vectors": None, "language_ids": None} - # model_args = ForwardTTSE2EArgs() - # model = ForwardTTSE2E(model_args) - # aux_out = model.get_aux_input(aux_input) + def test_inference(self): + model_args = ForwardTTSE2eArgs(spec_segment_size=10) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e(config).to(device) + model.eval() - # speaker_id = torch.randint(10, (1,)) - # language_id = torch.randint(10, (1,)) - # d_vector = torch.rand(1, 128) - # aux_input = {"speaker_ids": speaker_id, "style_wav": None, "d_vectors": d_vector, "language_ids": language_id} - # aux_out = model.get_aux_input(aux_input) - # self.assertEqual(aux_out["speaker_ids"].shape, speaker_id.shape) - # self.assertEqual(aux_out["language_ids"].shape, language_id.shape) - # self.assertEqual(aux_out["d_vectors"].shape, d_vector.unsqueeze(0).transpose(2, 1).shape) + batch_size = 1 + input_dummy, *_ = self._create_inputs(config, batch_size=batch_size) + outputs = model.inference(input_dummy.to(device)) + self._check_inference_outputs(outputs, input_dummy, batch_size=batch_size) - # def test_forward(self): - # model_args = ForwardTTSE2EArgs(spec_segment_size=10) - # config = FastPitchE2EConfig(model_args=model_args) - # input_dummy, input_lengths, spec, spec_lengths, waveform, pitch = self._create_inputs(config) - # model = ForwardTTSE2E(config).to(device) - # output_dict = model.forward( - # x=input_dummy, x_lengths=input_lengths, spec=spec, spec_lengths=spec_lengths, waveform=waveform, pitch=pitch - # ) - # self._check_forward_outputs(config, output_dict) + # TODO 
implement batched inference + # batch_size = 2 + # input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size) + # outputs = model.inference(input_dummy, aux_input={"x_lengths": input_lengths}) + # self._check_inference_outputs(outputs, input_dummy, batch_size=batch_size) - # def test_multispeaker_inference(self): - # num_speakers = 10 - # model_args = ForwardTTSE2EArgs( - # spec_segment_size=10, num_speakers=num_speakers, use_speaker_embedding=True - # ) - # config = FastPitchE2EConfig(model_args=model_args) + def test_multispeaker_inference(self): + num_speakers = 10 + model_args = ForwardTTSE2eArgs(spec_segment_size=10, num_speakers=num_speakers, use_speaker_embedding=True) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e(config).to(device) - # input_dummy, input_lengths, spec, spec_lengths, waveform, pitch = self._create_inputs( - # config, batch_size=batch_size - # ) - # speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + batch_size = 1 + input_dummy, *_ = self._create_inputs(config, batch_size=batch_size) + speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + outputs = model.inference(input_dummy, {"speaker_ids": speaker_ids}) + self._check_inference_outputs(outputs, input_dummy, batch_size=batch_size) - # model = ForwardTTSE2E(config).to(device) - # output_dict = model.forward( - # x=input_dummy, - # x_lengths=input_lengths, - # spec=spec, - # spec_lengths=spec_lengths, - # waveform=waveform, - # pitch=pitch, - # aux_input={"speaker_ids": speaker_ids}, - # ) - # self._check_forward_outputs(config, output_dict) + # batch_size = 2 + # input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size) + # speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) + # outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids":
speaker_ids}) + # self._check_inference_outputs(outputs, input_dummy, batch_size=batch_size) - # def test_d_vector_forward(self): - # batch_size = 2 - # model_args = ForwardTTSE2EArgs( - # spec_segment_size=10, use_d_vector_file=True, d_vector_dim=256 - # ) - # config = FastPitchE2EConfig(model_args=model_args) - # config.model_args.spec_segment_size = 10 - # model = ForwardTTSE2E(config).to(device) - # model.train() - # input_dummy, input_lengths, spec, spec_lengths, waveform, pitch = self._create_inputs( - # config, batch_size=batch_size - # ) - # d_vectors = torch.randn(batch_size, 256).to(device) - # output_dict = model.forward( - # x=input_dummy, - # x_lengths=input_lengths, - # spec=spec, - # spec_lengths=spec_lengths, - # waveform=waveform, - # pitch=pitch, - # aux_input={"d_vectors": d_vectors}, - # ) - # self._check_forward_outputs(config, output_dict) + # def test_multilingual_inference(self): + # """TODO""" - # # def test_multilingual_forward(self): - # # """TODO""" - - # def test_inference(self): - # model_args = ForwardTTSE2EArgs(spec_segment_size=10) - # config = FastPitchE2EConfig(model_args=model_args) - # model = ForwardTTSE2E(config).to(device) - # model.eval() - - # batch_size = 1 - # input_dummy, *_ = self._create_inputs(config, batch_size=batch_size) - # outputs = model.inference(input_dummy.to(device)) - # self._check_inference_outputs(outputs, input_dummy, batch_size=batch_size) - - # # TODO implemented batched inferenece - # # batch_size = 2 - # # input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size) - # # outputs = model.inference(input_dummy, aux_input={"x_lengths": input_lengths}) - # # self._check_inference_outputs(outputs, input_dummy, batch_size=batch_size) - - # def test_multispeaker_inference(self): - # num_speakers = 10 - # model_args = ForwardTTSE2EArgs( - # spec_segment_size=10, num_speakers=num_speakers, use_speaker_embedding=True - # ) - # config = FastPitchE2EConfig(model_args=model_args) - # model 
= ForwardTTSE2E(config).to(device) - - # batch_size = 1 - # input_dummy, *_ = self._create_inputs(config, batch_size=batch_size) - # speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) - # outputs = model.inference(input_dummy, {"speaker_ids": speaker_ids}) - # self._check_inference_outputs(outputs, input_dummy, batch_size=batch_size) - - # # batch_size = 2 - # # input_dummy, input_lengths, *_ = self._create_inputs(config, batch_size=batch_size) - # # speaker_ids = torch.randint(0, num_speakers, (batch_size,)).long().to(device) - # # outputs = model.inference(input_dummy, {"x_lengths": input_lengths, "speaker_ids": speaker_ids}) - # # self._check_inference_outputs(outputs, input_dummy, batch_size=batch_size) - - # # def test_multilingual_inference(self): - # # """TODO""" - - # def test_d_vector_inference(self): - # model_args = ForwardTTSE2EArgs( - # spec_segment_size=10, - # num_chars=32, - # use_d_vector_file=True, - # d_vector_dim=256, - # d_vector_file=os.path.join(get_tests_data_path(), "dummy_speakers.json"), - # ) - # config = FastPitchE2EConfig(model_args=model_args) - # model = ForwardTTSE2E(config).to(device) - # model.eval() - # # batch size = 1 - # input_dummy = torch.randint(0, 24, (1, 128)).long().to(device) - # d_vectors = torch.randn(1, 256).to(device) - # outputs = model.inference(input_dummy, aux_input={"d_vectors": d_vectors}) - # self._check_inference_outputs(outputs, input_dummy) - # # batch size = 2 - # # input_dummy, input_lengths, *_ = self._create_inputs(config) - # # d_vectors = torch.randn(2, 256).to(device) - # # outputs = model.inference(input_dummy, aux_input={"x_lengths": input_lengths, "d_vectors": d_vectors}) - # # self._check_inference_outputs(outputs, input_dummy, batch_size=2) + def test_d_vector_inference(self): + model_args = ForwardTTSE2eArgs( + spec_segment_size=10, + num_chars=32, + use_d_vector_file=True, + d_vector_dim=256, + d_vector_file=os.path.join(get_tests_data_path(), 
"dummy_speakers.json"), + ) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e(config).to(device) + model.eval() + # batch size = 1 + input_dummy, *_ = self._create_inputs(config, batch_size=1) + d_vectors = torch.randn(1, 256).to(device) + outputs = model.inference(input_dummy, aux_input={"d_vectors": d_vectors}) + self._check_inference_outputs(outputs, input_dummy) + # batch size = 2 + # input_dummy, input_lengths, *_ = self._create_inputs(config) + # d_vectors = torch.randn(2, 256).to(device) + # outputs = model.inference(input_dummy, aux_input={"x_lengths": input_lengths, "d_vectors": d_vectors}) + # self._check_inference_outputs(outputs, input_dummy, batch_size=2) def test_train_step(self): # setup the model with torch.autograd.set_detect_anomaly(True): - model_args = ForwardTTSE2EArgs(spec_segment_size=10) - config = FastPitchE2EConfig(model_args=model_args) - model = ForwardTTSE2E(config).to(device) + model_args = ForwardTTSE2eArgs(spec_segment_size=10) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e(config).to(device) model.train() # model to train optimizers = model.get_optimizer() criterions = model.get_criterion() criterions = [criterions[0].to(device), criterions[1].to(device)] # reference model to compare model weights - model_ref = ForwardTTSE2E(config).to(device) + model_ref = ForwardTTSE2e(config).to(device) # # pass the state to ref model model_ref.load_state_dict(copy.deepcopy(model.state_dict())) count = 0 @@ -277,10 +276,11 @@ class TestFastPitchE2E(unittest.TestCase): def test_train_eval_log(self): batch_size = 2 - model_args = ForwardTTSE2EArgs(spec_segment_size=10) - config = FastPitchE2EConfig(model_args=model_args) - model = ForwardTTSE2E.init_from_config(config, verbose=False).to(device) + model_args = ForwardTTSE2eArgs(spec_segment_size=10) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, verbose=False).to(device) model.train() + 
model.on_init_start(trainer=None) # create mel_basis batch = self._create_batch(config, batch_size) logger = TensorboardLogger( log_dir=os.path.join(get_tests_output_path(), "dummy_fast_pitch_e2e_logs"), @@ -296,19 +296,20 @@ class TestFastPitchE2E(unittest.TestCase): logger.finish() def test_test_run(self): - model_args = ForwardTTSE2EArgs(spec_segment_size=10) - config = FastPitchE2EConfig(model_args=model_args) - model = ForwardTTSE2E.init_from_config(config, verbose=False).to(device) + model_args = ForwardTTSE2eArgs(spec_segment_size=10) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, verbose=False).to(device) model.eval() + model.on_init_start(trainer=None) # create mel_basis test_figures, test_audios = model.test_run(None) self.assertTrue(test_figures is not None) self.assertTrue(test_audios is not None) def test_load_checkpoint(self): chkp_path = os.path.join(get_tests_output_path(), "dummy_fast_pitch_e2e_tts_checkpoint.pth") - model_args = ForwardTTSE2EArgs(spec_segment_size=10) - config = FastPitchE2EConfig(model_args=model_args) - model = ForwardTTSE2E.init_from_config(config, verbose=False).to(device) + model_args = ForwardTTSE2eArgs(spec_segment_size=10) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, verbose=False).to(device) chkp = {} chkp["model"] = model.state_dict() torch.save(chkp, chkp_path) @@ -318,49 +319,47 @@ class TestFastPitchE2E(unittest.TestCase): self.assertFalse(model.training) def test_get_criterion(self): - model_args = ForwardTTSE2EArgs(spec_segment_size=10) - config = FastPitchE2EConfig(model_args=model_args) - model = ForwardTTSE2E.init_from_config(config, verbose=False).to(device) + model_args = ForwardTTSE2eArgs(spec_segment_size=10) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, verbose=False).to(device) criterion = model.get_criterion() self.assertTrue(criterion is 
not None) def test_init_from_config(self): - model_args = ForwardTTSE2EArgs(spec_segment_size=10) - config = FastPitchE2EConfig(model_args=model_args) - model = ForwardTTSE2E.init_from_config(config, verbose=False).to(device) + model_args = ForwardTTSE2eArgs(spec_segment_size=10) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, verbose=False).to(device) - model_args = ForwardTTSE2EArgs(spec_segment_size=10, num_speakers=2) - config = FastPitchE2EConfig(model_args=model_args) - model = ForwardTTSE2E.init_from_config(config, verbose=False).to(device) + model_args = ForwardTTSE2eArgs(spec_segment_size=10, num_speakers=2) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, verbose=False).to(device) self.assertTrue(not hasattr(model, "emb_g")) - model_args = ForwardTTSE2EArgs( - spec_segment_size=10, num_speakers=2, use_speaker_embedding=True - ) - config = FastPitchE2EConfig(model_args=model_args) - model = ForwardTTSE2E.init_from_config(config, verbose=False).to(device) + model_args = ForwardTTSE2eArgs(spec_segment_size=10, num_speakers=2, use_speaker_embedding=True) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, verbose=False).to(device) self.assertEqual(model.num_speakers, 2) self.assertTrue(hasattr(model, "emb_g")) - model_args = ForwardTTSE2EArgs( - spec_segment_size=10, + model_args = ForwardTTSE2eArgs( + spec_segment_size=10, num_speakers=2, use_speaker_embedding=True, speakers_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), ) - config = FastPitchE2EConfig(model_args=model_args) - model = ForwardTTSE2E.init_from_config(config, verbose=False).to(device) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, verbose=False).to(device) self.assertEqual(model.num_speakers, 10) self.assertTrue(hasattr(model, "emb_g")) - model_args = 
ForwardTTSE2EArgs( - spec_segment_size=10, + model_args = ForwardTTSE2eArgs( + spec_segment_size=10, use_d_vector_file=True, d_vector_dim=256, d_vector_file=os.path.join(get_tests_data_path(), "ljspeech", "speakers.json"), ) - config = FastPitchE2EConfig(model_args=model_args) - model = ForwardTTSE2E.init_from_config(config, verbose=False).to(device) + config = FastPitchE2eConfig(model_args=model_args) + model = ForwardTTSE2e.init_from_config(config, verbose=False).to(device) self.assertTrue(model.num_speakers == 10) self.assertTrue(not hasattr(model, "emb_g")) self.assertTrue(model.embedded_speaker_dim == config.model_args.d_vector_dim)