mirror of https://github.com/coqui-ai/TTS.git
Export multispeaker onnx (#2743)
This commit is contained in:
parent 08bc758cad
commit 7b5c8422c8
@@ -15,6 +15,8 @@ from einops import pack, unpack
 from torch import nn
+from torchaudio.functional import resample
+from transformers import HubertModel


 def round_down_nearest_multiple(num, divisor):
     return num // divisor * divisor

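As a quick aside (not part of the commit), the round_down_nearest_multiple helper visible in this hunk's context floors num to the nearest lower multiple of divisor:

def round_down_nearest_multiple(num, divisor):
    return num // divisor * divisor

print(round_down_nearest_multiple(37, 8))       # 32    (37 // 8 * 8)
print(round_down_nearest_multiple(16000, 320))  # 16000 (already a multiple)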
@@ -1875,7 +1875,11 @@ class Vits(BaseTTS):
     def load_onnx(self, model_path: str, cuda=False):
         import onnxruntime as ort

-        providers = ["CPUExecutionProvider" if cuda is False else "CUDAExecutionProvider"]
+        providers = [
+            "CPUExecutionProvider"
+            if cuda is False
+            else ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"})
+        ]
         sess_options = ort.SessionOptions()
         self.onnx_sess = ort.InferenceSession(
             model_path,
@@ -1883,11 +1887,8 @@ class Vits(BaseTTS):
             providers=providers,
         )

-    def inference_onnx(self, x, x_lengths=None):
-        """ONNX inference (only single speaker models are supported)
-
-        TODO: implement multi speaker support.
-        """
+    def inference_onnx(self, x, x_lengths=None, speaker_id=None):
+        """ONNX inference"""

         if isinstance(x, torch.Tensor):
             x = x.cpu().numpy()
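For reference, onnxruntime's providers argument accepts either bare provider names or (name, options) pairs; the load_onnx change above uses the pair form to pass a CUDA provider option. A minimal standalone sketch (the model path is a placeholder, assuming onnxruntime / onnxruntime-gpu is installed):

import onnxruntime as ort

model_path = "coqui_vits.onnx"  # placeholder path

# CPU: a bare provider name suffices.
cpu_sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

# CUDA: the (name, options) tuple carries provider options, here switching
# cudnn_conv_algo_search from its exhaustive default to "DEFAULT", as the
# diff above does.
cuda_sess = ort.InferenceSession(
    model_path,
    providers=[("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"})],
)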
@@ -1907,7 +1908,7 @@ class Vits(BaseTTS):
                 "input": x,
                 "input_lengths": x_lengths,
                 "scales": scales,
-                "sid": None,
+                "sid": torch.tensor([speaker_id]).cpu().numpy(),
             },
         )
         return audio[0][0]
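Taken together with export_onnx, the Vits ONNX path now handles multi-speaker models: the new speaker_id argument is fed into the graph's "sid" input rather than a hard-coded None. A rough end-to-end sketch, with hypothetical paths and assuming a multi-speaker VITS checkpoint (e.g. one trained on VCTK):

import numpy as np

from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits

config = VitsConfig()
config.load_json("config.json")                  # multi-speaker training config (placeholder)
model = Vits.init_from_config(config)
model.load_checkpoint(config, "model_file.pth")  # placeholder checkpoint path
model.export_onnx(output_path="coqui_vits.onnx")

model.load_onnx("coqui_vits.onnx", cuda=False)
x = np.asarray(model.tokenizer.text_to_ids("Hello world!"), dtype=np.int64)[None, :]
audio = model.inference_onnx(x, speaker_id=0)    # becomes the "sid" input above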