mirror of https://github.com/coqui-ai/TTS.git
Export multispeaker onnx (#2743)
This commit is contained in:
parent 08bc758cad
commit 7b5c8422c8
@@ -15,6 +15,8 @@ from einops import pack, unpack
 from torch import nn
 from torchaudio.functional import resample
 from transformers import HubertModel


 def round_down_nearest_multiple(num, divisor):
     return num // divisor * divisor
@@ -1875,7 +1875,11 @@ class Vits(BaseTTS):
     def load_onnx(self, model_path: str, cuda=False):
         import onnxruntime as ort

-        providers = ["CPUExecutionProvider" if cuda is False else "CUDAExecutionProvider"]
+        providers = [
+            "CPUExecutionProvider"
+            if cuda is False
+            else ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"})
+        ]
         sess_options = ort.SessionOptions()
         self.onnx_sess = ort.InferenceSession(
             model_path,
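For context (not part of the commit): ONNX Runtime accepts each entry of `providers` either as a plain provider name or as a `(name, options)` tuple, which is how the CUDA-specific `cudnn_conv_algo_search` setting is passed above. A minimal standalone sketch with a placeholder model path:

import onnxruntime as ort

sess_options = ort.SessionOptions()
session = ort.InferenceSession(
    "model.onnx",  # placeholder path to an exported model
    sess_options=sess_options,
    providers=[
        # A (name, options) tuple carries provider-specific settings;
        # a bare string uses that provider's defaults.
        ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}),
        "CPUExecutionProvider",  # fallback if CUDA is unavailable
    ],
)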
@@ -1883,11 +1887,8 @@ class Vits(BaseTTS):
             providers=providers,
         )

-    def inference_onnx(self, x, x_lengths=None):
-        """ONNX inference (only single speaker models are supported)
-
-        TODO: implement multi speaker support.
-        """
+    def inference_onnx(self, x, x_lengths=None, speaker_id=None):
+        """ONNX inference"""

         if isinstance(x, torch.Tensor):
             x = x.cpu().numpy()
@@ -1907,7 +1908,7 @@ class Vits(BaseTTS):
                 "input": x,
                 "input_lengths": x_lengths,
                 "scales": scales,
-                "sid": None,
+                "sid": torch.tensor([speaker_id]).cpu().numpy(),
             },
         )
         return audio[0][0]
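A minimal usage sketch (not part of the commit) of how the new `speaker_id` argument would be used once a multispeaker model has been exported to ONNX. The config/model paths and the speaker index are placeholders, and the `config.load_json`, `Vits.init_from_config`, and `tokenizer.text_to_ids` calls are assumptions based on the surrounding coqui-ai/TTS API rather than anything shown in this diff:

import numpy as np

from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits

config = VitsConfig()
config.load_json("config.json")  # hypothetical path to the trained model's config
model = Vits.init_from_config(config)
model.load_onnx("coqui_vits.onnx", cuda=False)  # hypothetical exported model path

# Convert input text to token ids shaped (1, T), as inference_onnx expects.
text = "Hello from a multispeaker ONNX model."
x = np.asarray(model.tokenizer.text_to_ids(text), dtype=np.int64)[None, :]

# speaker_id selects the speaker embedding baked into the exported graph,
# and is fed to the session as the "sid" input shown above.
wav = model.inference_onnx(x, speaker_id=0)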