mirror of https://github.com/coqui-ai/TTS.git
Adds multi-language support for VITS ONNX export; fixes the ONNX inference error when speaker_id is None or not passed; fixes ONNX export for models with init_discriminator=false (#2816)
This commit is contained in:
parent
b739326503
commit
c140df5a58
|
@ -1813,14 +1813,16 @@ class Vits(BaseTTS):
|
||||||
|
|
||||||
# rollback values
|
# rollback values
|
||||||
_forward = self.forward
|
_forward = self.forward
|
||||||
disc = self.disc
|
disc = None
|
||||||
|
if hasattr(self, 'disc'):
|
||||||
|
disc = self.disc
|
||||||
training = self.training
|
training = self.training
|
||||||
|
|
||||||
# set export mode
|
# set export mode
|
||||||
self.disc = None
|
self.disc = None
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
def onnx_inference(text, text_lengths, scales, sid=None):
|
def onnx_inference(text, text_lengths, scales, sid=None, langid=None):
|
||||||
noise_scale = scales[0]
|
noise_scale = scales[0]
|
||||||
length_scale = scales[1]
|
length_scale = scales[1]
|
||||||
noise_scale_dp = scales[2]
|
noise_scale_dp = scales[2]
|
||||||
|
@ -1833,7 +1835,7 @@ class Vits(BaseTTS):
|
||||||
"x_lengths": text_lengths,
|
"x_lengths": text_lengths,
|
||||||
"d_vectors": None,
|
"d_vectors": None,
|
||||||
"speaker_ids": sid,
|
"speaker_ids": sid,
|
||||||
"language_ids": None,
|
"language_ids": langid,
|
||||||
"durations": None,
|
"durations": None,
|
||||||
},
|
},
|
||||||
)["model_outputs"]
|
)["model_outputs"]
|
||||||
|
@ -1844,11 +1846,14 @@ class Vits(BaseTTS):
|
||||||
dummy_input_length = 100
|
dummy_input_length = 100
|
||||||
sequences = torch.randint(low=0, high=self.args.num_chars, size=(1, dummy_input_length), dtype=torch.long)
|
sequences = torch.randint(low=0, high=self.args.num_chars, size=(1, dummy_input_length), dtype=torch.long)
|
||||||
sequence_lengths = torch.LongTensor([sequences.size(1)])
|
sequence_lengths = torch.LongTensor([sequences.size(1)])
|
||||||
sepaker_id = None
|
speaker_id = None
|
||||||
|
language_id = None
|
||||||
if self.num_speakers > 1:
|
if self.num_speakers > 1:
|
||||||
sepaker_id = torch.LongTensor([0])
|
speaker_id = torch.LongTensor([0])
|
||||||
|
if self.num_languages > 0 and self.embedded_language_dim > 0:
|
||||||
|
language_id = torch.LongTensor([0])
|
||||||
scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
|
scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
|
||||||
dummy_input = (sequences, sequence_lengths, scales, sepaker_id)
|
dummy_input = (sequences, sequence_lengths, scales, speaker_id, language_id)
|
||||||
|
|
||||||
# export to ONNX
|
# export to ONNX
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
|
@ -1857,7 +1862,7 @@ class Vits(BaseTTS):
|
||||||
opset_version=15,
|
opset_version=15,
|
||||||
f=output_path,
|
f=output_path,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
input_names=["input", "input_lengths", "scales", "sid"],
|
input_names=["input", "input_lengths", "scales", "sid", "langid"],
|
||||||
output_names=["output"],
|
output_names=["output"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"input": {0: "batch_size", 1: "phonemes"},
|
"input": {0: "batch_size", 1: "phonemes"},
|
||||||
|
@ -1870,7 +1875,8 @@ class Vits(BaseTTS):
|
||||||
self.forward = _forward
|
self.forward = _forward
|
||||||
if training:
|
if training:
|
||||||
self.train()
|
self.train()
|
||||||
self.disc = disc
|
if not disc is None:
|
||||||
|
self.disc = disc
|
||||||
|
|
||||||
def load_onnx(self, model_path: str, cuda=False):
|
def load_onnx(self, model_path: str, cuda=False):
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
@ -1887,7 +1893,7 @@ class Vits(BaseTTS):
|
||||||
providers=providers,
|
providers=providers,
|
||||||
)
|
)
|
||||||
|
|
||||||
def inference_onnx(self, x, x_lengths=None, speaker_id=None):
|
def inference_onnx(self, x, x_lengths=None, speaker_id=None, language_id=None):
|
||||||
"""ONNX inference"""
|
"""ONNX inference"""
|
||||||
|
|
||||||
if isinstance(x, torch.Tensor):
|
if isinstance(x, torch.Tensor):
|
||||||
|
@ -1902,13 +1908,15 @@ class Vits(BaseTTS):
|
||||||
[self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp],
|
[self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp],
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
)
|
)
|
||||||
|
|
||||||
audio = self.onnx_sess.run(
|
audio = self.onnx_sess.run(
|
||||||
["output"],
|
["output"],
|
||||||
{
|
{
|
||||||
"input": x,
|
"input": x,
|
||||||
"input_lengths": x_lengths,
|
"input_lengths": x_lengths,
|
||||||
"scales": scales,
|
"scales": scales,
|
||||||
"sid": torch.tensor([speaker_id]).cpu().numpy(),
|
"sid": None if speaker_id is None else torch.tensor([speaker_id]).cpu().numpy(),
|
||||||
|
"langid": None if language_id is None else torch.tensor([language_id]).cpu().numpy()
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return audio[0][0]
|
return audio[0][0]
|
||||||
|
|
Loading…
Reference in New Issue