Adds multi-language support for VITS onnx, fixes onnx inference error when speaker_id is None or not passed, fixes onnx exporting for models with init_discriminator=false (#2816)

This commit is contained in:
Javier 2023-07-31 03:19:49 -05:00 committed by GitHub
parent b739326503
commit c140df5a58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 18 additions and 10 deletions

View File

@@ -1813,6 +1813,8 @@ class Vits(BaseTTS):
         # rollback values
         _forward = self.forward
-        disc = self.disc
+        disc = None
+        if hasattr(self, 'disc'):
+            disc = self.disc
         training = self.training
@@ -1820,7 +1822,7 @@ class Vits(BaseTTS):
         self.disc = None
         self.eval()

-        def onnx_inference(text, text_lengths, scales, sid=None):
+        def onnx_inference(text, text_lengths, scales, sid=None, langid=None):
            noise_scale = scales[0]
            length_scale = scales[1]
            noise_scale_dp = scales[2]
@@ -1833,7 +1835,7 @@ class Vits(BaseTTS):
                    "x_lengths": text_lengths,
                    "d_vectors": None,
                    "speaker_ids": sid,
-                    "language_ids": None,
+                    "language_ids": langid,
                    "durations": None,
                },
            )["model_outputs"]
@@ -1844,11 +1846,14 @@ class Vits(BaseTTS):
        dummy_input_length = 100
        sequences = torch.randint(low=0, high=self.args.num_chars, size=(1, dummy_input_length), dtype=torch.long)
        sequence_lengths = torch.LongTensor([sequences.size(1)])
-        sepaker_id = None
+        speaker_id = None
+        language_id = None
        if self.num_speakers > 1:
-            sepaker_id = torch.LongTensor([0])
+            speaker_id = torch.LongTensor([0])
+        if self.num_languages > 0 and self.embedded_language_dim > 0:
+            language_id = torch.LongTensor([0])
        scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
-        dummy_input = (sequences, sequence_lengths, scales, sepaker_id)
+        dummy_input = (sequences, sequence_lengths, scales, speaker_id, language_id)

        # export to ONNX
        torch.onnx.export(
@@ -1857,7 +1862,7 @@ class Vits(BaseTTS):
            opset_version=15,
            f=output_path,
            verbose=verbose,
-            input_names=["input", "input_lengths", "scales", "sid"],
+            input_names=["input", "input_lengths", "scales", "sid", "langid"],
            output_names=["output"],
            dynamic_axes={
                "input": {0: "batch_size", 1: "phonemes"},
@@ -1870,6 +1875,7 @@ class Vits(BaseTTS):
        self.forward = _forward
        if training:
            self.train()
-        self.disc = disc
+        if not disc is None:
+            self.disc = disc

    def load_onnx(self, model_path: str, cuda=False):
@@ -1887,7 +1893,7 @@ class Vits(BaseTTS):
            providers=providers,
        )

-    def inference_onnx(self, x, x_lengths=None, speaker_id=None):
+    def inference_onnx(self, x, x_lengths=None, speaker_id=None, language_id=None):
        """ONNX inference"""

        if isinstance(x, torch.Tensor):
@@ -1902,13 +1908,15 @@ class Vits(BaseTTS):
            [self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp],
            dtype=np.float32,
        )
+
        audio = self.onnx_sess.run(
            ["output"],
            {
                "input": x,
                "input_lengths": x_lengths,
                "scales": scales,
-                "sid": torch.tensor([speaker_id]).cpu().numpy(),
+                "sid": None if speaker_id is None else torch.tensor([speaker_id]).cpu().numpy(),
+                "langid": None if language_id is None else torch.tensor([language_id]).cpu().numpy()
            },
        )
        return audio[0][0]