Export multispeaker onnx (#2743)

This commit is contained in:
Eren Gölge 2023-07-06 13:36:50 +02:00 committed by GitHub
parent 08bc758cad
commit 7b5c8422c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 7 deletions

View File

@@ -15,6 +15,8 @@ from einops import pack, unpack
from torch import nn from torch import nn
from torchaudio.functional import resample from torchaudio.functional import resample
from transformers import HubertModel from transformers import HubertModel
def round_down_nearest_multiple(num, divisor): def round_down_nearest_multiple(num, divisor):
return num // divisor * divisor return num // divisor * divisor

View File

@@ -1875,7 +1875,11 @@ class Vits(BaseTTS):
def load_onnx(self, model_path: str, cuda=False): def load_onnx(self, model_path: str, cuda=False):
import onnxruntime as ort import onnxruntime as ort
providers = ["CPUExecutionProvider" if cuda is False else "CUDAExecutionProvider"] providers = [
"CPUExecutionProvider"
if cuda is False
else ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"})
]
sess_options = ort.SessionOptions() sess_options = ort.SessionOptions()
self.onnx_sess = ort.InferenceSession( self.onnx_sess = ort.InferenceSession(
model_path, model_path,
@@ -1883,11 +1887,8 @@ class Vits(BaseTTS):
providers=providers, providers=providers,
) )
def inference_onnx(self, x, x_lengths=None): def inference_onnx(self, x, x_lengths=None, speaker_id=None):
"""ONNX inference (only single speaker models are supported) """ONNX inference"""
TODO: implement multi speaker support.
"""
if isinstance(x, torch.Tensor): if isinstance(x, torch.Tensor):
x = x.cpu().numpy() x = x.cpu().numpy()
@@ -1907,7 +1908,7 @@ class Vits(BaseTTS):
"input": x, "input": x,
"input_lengths": x_lengths, "input_lengths": x_lengths,
"scales": scales, "scales": scales,
"sid": None, "sid": torch.tensor([speaker_id]).cpu().numpy(),
}, },
) )
return audio[0][0] return audio[0][0]