Adds multi-language support for VITS onnx, fixes onnx inference error when speaker_id is None or not passed, fixes onnx exporting for models with init_discriminator=false (#2816)

This commit is contained in:
Javier 2023-07-31 03:19:49 -05:00 committed by GitHub
parent b739326503
commit c140df5a58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 18 additions and 10 deletions

View File

@ -1813,14 +1813,16 @@ class Vits(BaseTTS):
# rollback values
_forward = self.forward
disc = self.disc
disc = None
if hasattr(self, 'disc'):
disc = self.disc
training = self.training
# set export mode
self.disc = None
self.eval()
def onnx_inference(text, text_lengths, scales, sid=None):
def onnx_inference(text, text_lengths, scales, sid=None, langid=None):
noise_scale = scales[0]
length_scale = scales[1]
noise_scale_dp = scales[2]
@ -1833,7 +1835,7 @@ class Vits(BaseTTS):
"x_lengths": text_lengths,
"d_vectors": None,
"speaker_ids": sid,
"language_ids": None,
"language_ids": langid,
"durations": None,
},
)["model_outputs"]
@ -1844,11 +1846,14 @@ class Vits(BaseTTS):
dummy_input_length = 100
sequences = torch.randint(low=0, high=self.args.num_chars, size=(1, dummy_input_length), dtype=torch.long)
sequence_lengths = torch.LongTensor([sequences.size(1)])
sepaker_id = None
speaker_id = None
language_id = None
if self.num_speakers > 1:
sepaker_id = torch.LongTensor([0])
speaker_id = torch.LongTensor([0])
if self.num_languages > 0 and self.embedded_language_dim > 0:
language_id = torch.LongTensor([0])
scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
dummy_input = (sequences, sequence_lengths, scales, sepaker_id)
dummy_input = (sequences, sequence_lengths, scales, speaker_id, language_id)
# export to ONNX
torch.onnx.export(
@ -1857,7 +1862,7 @@ class Vits(BaseTTS):
opset_version=15,
f=output_path,
verbose=verbose,
input_names=["input", "input_lengths", "scales", "sid"],
input_names=["input", "input_lengths", "scales", "sid", "langid"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch_size", 1: "phonemes"},
@ -1870,7 +1875,8 @@ class Vits(BaseTTS):
self.forward = _forward
if training:
self.train()
self.disc = disc
if not disc is None:
self.disc = disc
def load_onnx(self, model_path: str, cuda=False):
import onnxruntime as ort
@ -1887,7 +1893,7 @@ class Vits(BaseTTS):
providers=providers,
)
def inference_onnx(self, x, x_lengths=None, speaker_id=None):
def inference_onnx(self, x, x_lengths=None, speaker_id=None, language_id=None):
"""ONNX inference"""
if isinstance(x, torch.Tensor):
@ -1902,13 +1908,15 @@ class Vits(BaseTTS):
[self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp],
dtype=np.float32,
)
audio = self.onnx_sess.run(
["output"],
{
"input": x,
"input_lengths": x_lengths,
"scales": scales,
"sid": torch.tensor([speaker_id]).cpu().numpy(),
"sid": None if speaker_id is None else torch.tensor([speaker_id]).cpu().numpy(),
"langid": None if language_id is None else torch.tensor([language_id]).cpu().numpy()
},
)
return audio[0][0]