From 4de797bb11736e329f89e2407d1001f1b9f47844 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Tue, 16 May 2023 01:07:56 +0200
Subject: [PATCH] Draft ONNX export for VITS (#2563)

* Draft ONNX export for VITS

Could not get it to output variable-length sequences

* Fixup for onnx constant output

* Make style

* Remove commented code
---
 TTS/tts/models/vits.py   | 109 +++++++++++++++++++++++++++++++++++++++
 TTS/tts/utils/helpers.py |   7 +--
 2 files changed, 111 insertions(+), 5 deletions(-)

diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 73095b34..2e0c32c8 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -1758,6 +1758,115 @@ class Vits(BaseTTS):
         )
         return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
 
+    def export_onnx(self, output_path: str = "coqui_vits.onnx", verbose: bool = True):
+        """Export model to ONNX format for inference
+
+        Args:
+            output_path (str): Path to save the exported model.
+            verbose (bool): Print verbose information. Defaults to True.
+        """
+
+        # rollback values
+        _forward = self.forward
+        disc = self.disc
+        training = self.training
+
+        # set export mode
+        self.disc = None
+        self.eval()
+
+        def onnx_inference(text, text_lengths, scales, sid=None):
+            noise_scale = scales[0]
+            length_scale = scales[1]
+            noise_scale_dp = scales[2]
+            self.noise_scale = noise_scale
+            self.length_scale = length_scale
+            self.noise_scale_dp = noise_scale_dp
+            return self.inference(
+                text,
+                aux_input={
+                    "x_lengths": text_lengths,
+                    "d_vectors": None,
+                    "speaker_ids": sid,
+                    "language_ids": None,
+                    "durations": None,
+                },
+            )["model_outputs"]
+
+        self.forward = onnx_inference
+
+        # set dummy inputs
+        dummy_input_length = 100
+        sequences = torch.randint(low=0, high=self.args.num_chars, size=(1, dummy_input_length), dtype=torch.long)
+        sequence_lengths = torch.LongTensor([sequences.size(1)])
+        speaker_id = None
+        if self.num_speakers > 1:
+            speaker_id = torch.LongTensor([0])
+        scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
+        dummy_input = (sequences, sequence_lengths, scales, speaker_id)
+
+        # export to ONNX
+        torch.onnx.export(
+            model=self,
+            args=dummy_input,
+            opset_version=15,
+            f=output_path,
+            verbose=verbose,
+            input_names=["input", "input_lengths", "scales", "sid"],
+            output_names=["output"],
+            dynamic_axes={
+                "input": {0: "batch_size", 1: "phonemes"},
+                "input_lengths": {0: "batch_size"},
+                "output": {0: "batch_size", 1: "time1", 2: "time2"},
+            },
+        )
+
+        # rollback
+        self.forward = _forward
+        if training:
+            self.train()
+        self.disc = disc
+
+    def load_onnx(self, model_path: str, cuda=False):
+        import onnxruntime as ort
+
+        providers = ["CPUExecutionProvider" if cuda is False else "CUDAExecutionProvider"]
+        sess_options = ort.SessionOptions()
+        self.onnx_sess = ort.InferenceSession(
+            model_path,
+            sess_options=sess_options,
+            providers=providers,
+        )
+
+    def inference_onnx(self, x, x_lengths=None):
+        """ONNX inference (only single speaker models are supported)
+
+        TODO: implement multi-speaker support.
+ """ + + if isinstance(x, torch.Tensor): + x = x.cpu().numpy() + + if x_lengths is None: + x_lengths = np.array([x.shape[1]], dtype=np.int64) + + if isinstance(x_lengths, torch.Tensor): + x_lengths = x_lengths.cpu().numpy() + scales = np.array( + [self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp], + dtype=np.float32, + ) + audio = self.onnx_sess.run( + ["output"], + { + "input": x, + "input_lengths": x_lengths, + "scales": scales, + "sid": None, + }, + ) + return audio[0][0] + ################################## # VITS CHARACTERS diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index b62004c8..56ef2944 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -50,11 +50,10 @@ def sequence_mask(sequence_length, max_len=None): - mask: :math:`[B, T_max]` """ if max_len is None: - max_len = sequence_length.data.max() + max_len = sequence_length.max() seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device) # B x T_max - mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1) - return mask + return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1) def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False): @@ -158,10 +157,8 @@ def generate_path(duration, mask): - mask: :math:'[B, T_en, T_de]` - path: :math:`[B, T_en, T_de]` """ - device = duration.device b, t_x, t_y = mask.shape cum_duration = torch.cumsum(duration, 1) - path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) cum_duration_flat = cum_duration.view(b * t_x) path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)