mirror of https://github.com/coqui-ai/TTS.git
Add fairseq onnx support and strict configuration, fixes some onnx errors (#2831)
This commit is contained in:
parent
52a528cfcf
commit
4e7f8cd021
|
@ -1725,7 +1725,7 @@ class Vits(BaseTTS):
|
||||||
assert not self.training
|
assert not self.training
|
||||||
|
|
||||||
def load_fairseq_checkpoint(
|
def load_fairseq_checkpoint(
|
||||||
self, config, checkpoint_dir, eval=False
|
self, config, checkpoint_dir, eval=False, strict=True
|
||||||
): # pylint: disable=unused-argument, redefined-builtin
|
): # pylint: disable=unused-argument, redefined-builtin
|
||||||
"""Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms
|
"""Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms
|
||||||
Performs some changes for compatibility.
|
Performs some changes for compatibility.
|
||||||
|
@ -1763,7 +1763,7 @@ class Vits(BaseTTS):
|
||||||
)
|
)
|
||||||
# load fairseq checkpoint
|
# load fairseq checkpoint
|
||||||
new_chk = rehash_fairseq_vits_checkpoint(checkpoint_file)
|
new_chk = rehash_fairseq_vits_checkpoint(checkpoint_file)
|
||||||
self.load_state_dict(new_chk)
|
self.load_state_dict(new_chk, strict=strict)
|
||||||
if eval:
|
if eval:
|
||||||
self.eval()
|
self.eval()
|
||||||
assert not self.training
|
assert not self.training
|
||||||
|
@ -1844,16 +1844,21 @@ class Vits(BaseTTS):
|
||||||
|
|
||||||
# set dummy inputs
|
# set dummy inputs
|
||||||
dummy_input_length = 100
|
dummy_input_length = 100
|
||||||
sequences = torch.randint(low=0, high=self.args.num_chars, size=(1, dummy_input_length), dtype=torch.long)
|
sequences = torch.randint(low=0, high=2, size=(1, dummy_input_length), dtype=torch.long)
|
||||||
sequence_lengths = torch.LongTensor([sequences.size(1)])
|
sequence_lengths = torch.LongTensor([sequences.size(1)])
|
||||||
speaker_id = None
|
|
||||||
language_id = None
|
|
||||||
if self.num_speakers > 1:
|
|
||||||
speaker_id = torch.LongTensor([0])
|
|
||||||
if self.num_languages > 0 and self.embedded_language_dim > 0:
|
|
||||||
language_id = torch.LongTensor([0])
|
|
||||||
scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
|
scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
|
||||||
dummy_input = (sequences, sequence_lengths, scales, speaker_id, language_id)
|
dummy_input = (sequences, sequence_lengths, scales)
|
||||||
|
input_names = ["input", "input_lengths", "scales"]
|
||||||
|
|
||||||
|
if self.num_speakers > 0:
|
||||||
|
speaker_id = torch.LongTensor([0])
|
||||||
|
dummy_input += (speaker_id, )
|
||||||
|
input_names.append("sid")
|
||||||
|
|
||||||
|
if hasattr(self, 'num_languages') and self.num_languages > 0 and self.embedded_language_dim > 0:
|
||||||
|
language_id = torch.LongTensor([0])
|
||||||
|
dummy_input += (language_id, )
|
||||||
|
input_names.append("langid")
|
||||||
|
|
||||||
# export to ONNX
|
# export to ONNX
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
|
@ -1862,7 +1867,7 @@ class Vits(BaseTTS):
|
||||||
opset_version=15,
|
opset_version=15,
|
||||||
f=output_path,
|
f=output_path,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
input_names=["input", "input_lengths", "scales", "sid", "langid"],
|
input_names=input_names,
|
||||||
output_names=["output"],
|
output_names=["output"],
|
||||||
dynamic_axes={
|
dynamic_axes={
|
||||||
"input": {0: "batch_size", 1: "phonemes"},
|
"input": {0: "batch_size", 1: "phonemes"},
|
||||||
|
@ -1908,16 +1913,19 @@ class Vits(BaseTTS):
|
||||||
[self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp],
|
[self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp],
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
)
|
)
|
||||||
|
input_params = {
|
||||||
|
"input": x,
|
||||||
|
"input_lengths": x_lengths,
|
||||||
|
"scales": scales
|
||||||
|
}
|
||||||
|
if not speaker_id is None:
|
||||||
|
input_params["sid"] = torch.tensor([speaker_id]).cpu().numpy()
|
||||||
|
if not language_id is None:
|
||||||
|
input_params["langid"] = torch.tensor([language_id]).cpu().numpy()
|
||||||
|
|
||||||
audio = self.onnx_sess.run(
|
audio = self.onnx_sess.run(
|
||||||
["output"],
|
["output"],
|
||||||
{
|
input_params,
|
||||||
"input": x,
|
|
||||||
"input_lengths": x_lengths,
|
|
||||||
"scales": scales,
|
|
||||||
"sid": None if speaker_id is None else torch.tensor([speaker_id]).cpu().numpy(),
|
|
||||||
"langid": None if language_id is None else torch.tensor([language_id]).cpu().numpy(),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
return audio[0][0]
|
return audio[0][0]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue