mirror of https://github.com/coqui-ai/TTS.git
Draft ONNX export for VITS (#2563)
* Draft ONNX export for VITS. Could not get it to work to output a variable-length sequence.
* Fixup for ONNX constant output.
* Make style.
* Remove commented code.
This commit is contained in:
parent 16c9df0dfe
commit 4de797bb11
TTS/tts/models/vits.py
@@ -1758,6 +1758,115 @@ class Vits(BaseTTS):
        )
        return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)

    def export_onnx(self, output_path: str = "coqui_vits.onnx", verbose: bool = True):
        """Export model to ONNX format for inference.

        Args:
            output_path (str): Path to save the exported model.
            verbose (bool): Print verbose information. Defaults to True.
        """
        # keep references to the current state so it can be rolled back after export
        _forward = self.forward
        disc = self.disc
        training = self.training

        # set export mode: drop the discriminator and switch to eval
        self.disc = None
        self.eval()

        def onnx_inference(text, text_lengths, scales, sid=None):
            # unpack the packed scales tensor into the three inference knobs
            noise_scale = scales[0]
            length_scale = scales[1]
            noise_scale_dp = scales[2]
            self.noise_scale = noise_scale
            self.length_scale = length_scale
            self.noise_scale_dp = noise_scale_dp
            return self.inference(
                text,
                aux_input={
                    "x_lengths": text_lengths,
                    "d_vectors": None,
                    "speaker_ids": sid,
                    "language_ids": None,
                    "durations": None,
                },
            )["model_outputs"]

        self.forward = onnx_inference

        # set dummy inputs
        dummy_input_length = 100
        sequences = torch.randint(low=0, high=self.args.num_chars, size=(1, dummy_input_length), dtype=torch.long)
        sequence_lengths = torch.LongTensor([sequences.size(1)])
        speaker_id = None
        if self.num_speakers > 1:
            speaker_id = torch.LongTensor([0])
        scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
        dummy_input = (sequences, sequence_lengths, scales, speaker_id)

        # export to ONNX
        torch.onnx.export(
            model=self,
            args=dummy_input,
            opset_version=15,
            f=output_path,
            verbose=verbose,
            input_names=["input", "input_lengths", "scales", "sid"],
            output_names=["output"],
            dynamic_axes={
                "input": {0: "batch_size", 1: "phonemes"},
                "input_lengths": {0: "batch_size"},
                "output": {0: "batch_size", 1: "time1", 2: "time2"},
            },
        )

        # rollback
        self.forward = _forward
        if training:
            self.train()
        self.disc = disc
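
(Aside, not part of the diff.) A minimal export sketch, assuming a trained checkpoint and its config; the file paths are placeholders:

from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits

config = VitsConfig()
config.load_json("config.json")                  # placeholder path
model = Vits.init_from_config(config)
model.load_checkpoint(config, "checkpoint.pth")  # placeholder path
model.export_onnx(output_path="coqui_vits.onnx")

The exported graph then takes "input", "input_lengths", and "scales" (packed as [noise_scale, length_scale, noise_scale_dp]), plus "sid" for multi-speaker models.
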
    def load_onnx(self, model_path: str, cuda=False):
        import onnxruntime as ort

        # pick the execution provider; CUDA requires the GPU build of ONNX Runtime
        providers = ["CPUExecutionProvider" if cuda is False else "CUDAExecutionProvider"]
        sess_options = ort.SessionOptions()
        self.onnx_sess = ort.InferenceSession(
            model_path,
            sess_options=sess_options,
            providers=providers,
        )
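
(Aside, not part of the diff.) The CUDA branch above assumes the GPU build of ONNX Runtime (onnxruntime-gpu); a quick sanity check of what your install actually offers:

import onnxruntime as ort

# e.g. ["CUDAExecutionProvider", "CPUExecutionProvider"] on a GPU build,
# just ["CPUExecutionProvider"] with the default CPU package
print(ort.get_available_providers())
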
    def inference_onnx(self, x, x_lengths=None):
        """ONNX inference (only single-speaker models are supported).

        TODO: implement multi-speaker support.
        """
        if isinstance(x, torch.Tensor):
            x = x.cpu().numpy()

        if x_lengths is None:
            x_lengths = np.array([x.shape[1]], dtype=np.int64)

        if isinstance(x_lengths, torch.Tensor):
            x_lengths = x_lengths.cpu().numpy()
        scales = np.array(
            [self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp],
            dtype=np.float32,
        )
        audio = self.onnx_sess.run(
            ["output"],
            {
                "input": x,
                "input_lengths": x_lengths,
                "scales": scales,
                "sid": None,
            },
        )
        return audio[0][0]
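
(Aside, not part of the diff.) Putting the two runtime methods together, a synthesis sketch; it assumes the `model` built in the export sketch above and that the tokenizer's `text_to_ids` is available:

import numpy as np

model.load_onnx("coqui_vits.onnx", cuda=False)

# tokenize text into IDs; `text_to_ids` is assumed from the model's TTSTokenizer
token_ids = model.tokenizer.text_to_ids("Hello world!")
x = np.asarray([token_ids], dtype=np.int64)  # [1, T] batch of token IDs
wav = model.inference_onnx(x)                # numpy waveform for the batch item
print(wav.shape)
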

##################################
# VITS CHARACTERS

TTS/tts/utils/helpers.py
@@ -50,11 +50,10 @@ def sequence_mask(sequence_length, max_len=None):
         - mask: :math:`[B, T_max]`
     """
     if max_len is None:
-        max_len = sequence_length.data.max()
+        max_len = sequence_length.max()
     seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
     # B x T_max
-    mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
-    return mask
+    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)


 def segment(x: torch.tensor, segment_indices: torch.tensor, segment_size=4, pad_short=False):
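
(Aside, not part of the diff.) `sequence_mask` broadcasts a [1, T_max] range against [B, 1] lengths to produce a [B, T_max] boolean mask; returning the comparison directly just drops the intermediate variable. A small check:

import torch

lengths = torch.tensor([2, 4])
max_len = lengths.max()            # 4
seq_range = torch.arange(max_len)  # [0, 1, 2, 3]
mask = seq_range.unsqueeze(0) < lengths.unsqueeze(1)
print(mask)
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])
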

@@ -158,10 +157,8 @@ def generate_path(duration, mask):
         - mask: :math:`[B, T_en, T_de]`
         - path: :math:`[B, T_en, T_de]`
     """
-    device = duration.device
     b, t_x, t_y = mask.shape
     cum_duration = torch.cumsum(duration, 1)
-    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)

     cum_duration_flat = cum_duration.view(b * t_x)
     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
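
(Aside, not part of the diff.) In this hunk, `cum_duration` turns per-token durations into prefix sums, and `sequence_mask` over the flattened sums marks, for encoder step i, the first cum_duration[i] decoder frames; the rest of the function (outside this hunk) differences adjacent rows to obtain the hard alignment path. The deleted `torch.zeros(b, t_x, t_y, ...)` buffer is presumably the "onnx constant output" the commit message mentions, since its shape gets captured as a constant under tracing. A small illustration of the prefix-mask step:

import torch

duration = torch.tensor([[2, 1, 3]])       # [B=1, T_en=3] token durations
t_y = int(duration.sum())                  # 6 decoder frames in total
cum_duration = torch.cumsum(duration, 1)   # [[2, 3, 6]] prefix sums
cum_duration_flat = cum_duration.view(-1)  # [2, 3, 6]

# inline sequence_mask: row i is True for the first cum_duration[i] frames
path = (torch.arange(t_y).unsqueeze(0) < cum_duration_flat.unsqueeze(1)).float()
print(path)
# tensor([[1., 1., 0., 0., 0., 0.],
#         [1., 1., 1., 0., 0., 0.],
#         [1., 1., 1., 1., 1., 1.]])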