diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index bc16ea63..f1567cf8 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1725,7 +1725,7 @@ class Vits(BaseTTS): assert not self.training def load_fairseq_checkpoint( - self, config, checkpoint_dir, eval=False + self, config, checkpoint_dir, eval=False, strict=True ): # pylint: disable=unused-argument, redefined-builtin """Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms Performs some changes for compatibility. @@ -1763,7 +1763,7 @@ class Vits(BaseTTS): ) # load fairseq checkpoint new_chk = rehash_fairseq_vits_checkpoint(checkpoint_file) - self.load_state_dict(new_chk) + self.load_state_dict(new_chk, strict=strict) if eval: self.eval() assert not self.training @@ -1844,17 +1844,22 @@ class Vits(BaseTTS): # set dummy inputs dummy_input_length = 100 - sequences = torch.randint(low=0, high=self.args.num_chars, size=(1, dummy_input_length), dtype=torch.long) + sequences = torch.randint(low=0, high=2, size=(1, dummy_input_length), dtype=torch.long) sequence_lengths = torch.LongTensor([sequences.size(1)]) - speaker_id = None - language_id = None - if self.num_speakers > 1: - speaker_id = torch.LongTensor([0]) - if self.num_languages > 0 and self.embedded_language_dim > 0: - language_id = torch.LongTensor([0]) scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp]) - dummy_input = (sequences, sequence_lengths, scales, speaker_id, language_id) - + dummy_input = (sequences, sequence_lengths, scales) + input_names = ["input", "input_lengths", "scales"] + + if self.num_speakers > 0: + speaker_id = torch.LongTensor([0]) + dummy_input += (speaker_id, ) + input_names.append("sid") + + if hasattr(self, 'num_languages') and self.num_languages > 0 and self.embedded_language_dim > 0: + language_id = torch.LongTensor([0]) + dummy_input += (language_id, ) + input_names.append("langid") + # export to ONNX torch.onnx.export( model=self, @@ -1862,7 +1867,7 @@ class Vits(BaseTTS): opset_version=15, f=output_path, verbose=verbose, - input_names=["input", "input_lengths", "scales", "sid", "langid"], + input_names=input_names, output_names=["output"], dynamic_axes={ "input": {0: "batch_size", 1: "phonemes"}, @@ -1870,7 +1875,7 @@ class Vits(BaseTTS): "output": {0: "batch_size", 1: "time1", 2: "time2"}, }, ) - + # rollback self.forward = _forward if training: @@ -1880,7 +1885,7 @@ class Vits(BaseTTS): def load_onnx(self, model_path: str, cuda=False): import onnxruntime as ort - + providers = [ "CPUExecutionProvider" if cuda is False @@ -1908,16 +1913,19 @@ class Vits(BaseTTS): [self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp], dtype=np.float32, ) - + input_params = { + "input": x, + "input_lengths": x_lengths, + "scales": scales + } + if not speaker_id is None: + input_params["sid"] = torch.tensor([speaker_id]).cpu().numpy() + if not language_id is None: + input_params["langid"] = torch.tensor([language_id]).cpu().numpy() + audio = self.onnx_sess.run( ["output"], - { - "input": x, - "input_lengths": x_lengths, - "scales": scales, - "sid": None if speaker_id is None else torch.tensor([speaker_id]).cpu().numpy(), - "langid": None if language_id is None else torch.tensor([language_id]).cpu().numpy(), - }, + input_params, ) return audio[0][0]