Export multispeaker onnx (#2743)

Eren Gölge 2023-07-06 13:36:50 +02:00 committed by GitHub
parent 08bc758cad
commit 7b5c8422c8
2 changed files with 10 additions and 7 deletions

View File

@@ -15,6 +15,8 @@ from einops import pack, unpack
from torch import nn
from torchaudio.functional import resample
from transformers import HubertModel

def round_down_nearest_multiple(num, divisor):
    return num // divisor * divisor
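For reference, `round_down_nearest_multiple` floors an integer to the nearest lower multiple of `divisor` via integer division. A small illustrative check (the values below are made up, not from the commit):

# Illustrative only: integer floor-division then multiplication floors to a multiple.
def round_down_nearest_multiple(num, divisor):
    return num // divisor * divisor

assert round_down_nearest_multiple(37, 8) == 32          # largest multiple of 8 that is <= 37
assert round_down_nearest_multiple(16000, 320) == 16000  # exact multiples pass through unchanged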

View File

@@ -1875,7 +1875,11 @@ class Vits(BaseTTS):
    def load_onnx(self, model_path: str, cuda=False):
        import onnxruntime as ort

-        providers = ["CPUExecutionProvider" if cuda is False else "CUDAExecutionProvider"]
+        providers = [
+            "CPUExecutionProvider"
+            if cuda is False
+            else ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"})
+        ]
        sess_options = ort.SessionOptions()
        self.onnx_sess = ort.InferenceSession(
            model_path,
@@ -1883,11 +1887,8 @@ class Vits(BaseTTS):
            providers=providers,
        )

-    def inference_onnx(self, x, x_lengths=None):
-        """ONNX inference (only single speaker models are supported)
-        TODO: implement multi speaker support.
-        """
+    def inference_onnx(self, x, x_lengths=None, speaker_id=None):
+        """ONNX inference"""

        if isinstance(x, torch.Tensor):
            x = x.cpu().numpy()
@@ -1907,7 +1908,7 @@ class Vits(BaseTTS):
                "input": x,
                "input_lengths": x_lengths,
                "scales": scales,
-                "sid": None,
+                "sid": torch.tensor([speaker_id]).cpu().numpy(),
            },
        )
        return audio[0][0]
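Taken together, the ONNX path now accepts an explicit speaker index instead of always sending `sid: None`. A rough usage sketch of the new call, not part of the commit: the config/checkpoint paths, the ONNX file name, and `speaker_id=0` are placeholders, and the surrounding `Vits` helpers (`init_from_config`, `load_onnx`, `tokenizer.text_to_ids`) are used as they exist in the repo.

# Illustrative sketch only: drive a multi-speaker VITS ONNX export with a speaker index.
import numpy as np

from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits

config = VitsConfig()
config.load_json("/path/to/config.json")  # placeholder path
model = Vits.init_from_config(config)

# One-time export of a multi-speaker checkpoint (placeholder paths), then load the ONNX graph:
# model.load_checkpoint(config, "/path/to/model.pth")
# model.export_onnx(output_path="coqui_vits.onnx")
model.load_onnx("coqui_vits.onnx", cuda=False)

text = "Hello from a multi-speaker ONNX model."
x = np.asarray(model.tokenizer.text_to_ids(text), dtype=np.int64)[None, :]

# speaker_id selects the speaker embedding baked into the exported graph;
# with this change it is forwarded to the "sid" input rather than hard-coded to None.
wav = model.inference_onnx(x, speaker_id=0)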