From 7b5c8422c83758e67ab0e2d0d3559a12a8321732 Mon Sep 17 00:00:00 2001
From: Eren Gölge
Date: Thu, 6 Jul 2023 13:36:50 +0200
Subject: [PATCH] Export multispeaker onnx (#2743)

---
 TTS/tts/layers/bark/hubert/kmeans_hubert.py |  2 ++
 TTS/tts/models/vits.py                      | 15 ++++++++-------
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/TTS/tts/layers/bark/hubert/kmeans_hubert.py b/TTS/tts/layers/bark/hubert/kmeans_hubert.py
index c7724c23..a6a3b9ae 100644
--- a/TTS/tts/layers/bark/hubert/kmeans_hubert.py
+++ b/TTS/tts/layers/bark/hubert/kmeans_hubert.py
@@ -15,6 +15,8 @@ from einops import pack, unpack
 from torch import nn
 from torchaudio.functional import resample
 from transformers import HubertModel
+
+
 def round_down_nearest_multiple(num, divisor):
     return num // divisor * divisor
 
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index bc96f5dc..f4f4c639 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -1875,7 +1875,11 @@ class Vits(BaseTTS):
     def load_onnx(self, model_path: str, cuda=False):
         import onnxruntime as ort
 
-        providers = ["CPUExecutionProvider" if cuda is False else "CUDAExecutionProvider"]
+        providers = [
+            "CPUExecutionProvider"
+            if cuda is False
+            else ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"})
+        ]
         sess_options = ort.SessionOptions()
         self.onnx_sess = ort.InferenceSession(
             model_path,
@@ -1883,11 +1887,8 @@ class Vits(BaseTTS):
             providers=providers,
         )
 
-    def inference_onnx(self, x, x_lengths=None):
-        """ONNX inference (only single speaker models are supported)
-
-        TODO: implement multi speaker support.
-        """
+    def inference_onnx(self, x, x_lengths=None, speaker_id=None):
+        """ONNX inference (single or multi-speaker, depending on the model)."""
 
         if isinstance(x, torch.Tensor):
             x = x.cpu().numpy()
@@ -1907,7 +1908,7 @@ class Vits(BaseTTS):
                 "input": x,
                 "input_lengths": x_lengths,
                 "scales": scales,
-                "sid": None,
+                "sid": None if speaker_id is None else torch.tensor([speaker_id]).cpu().numpy(),
             },
         )
         return audio[0][0]
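
Usage sketch (not part of the patch): a minimal end-to-end example of how the
multi-speaker export/inference path above might be exercised. It assumes the
Vits.export_onnx entry point available in this version of the repo; the
config/checkpoint paths, input text, and speaker id are placeholders, not
values taken from this change.

    import numpy as np

    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.models.vits import Vits

    # Load a trained multi-speaker VITS model (placeholder paths).
    config = VitsConfig()
    config.load_json("config.json")
    model = Vits.init_from_config(config)
    model.load_checkpoint(config, "model.pth")

    # Export the ONNX graph; with this patch the speaker-id ("sid") input
    # is actually fed at inference time instead of being hard-coded to None.
    model.export_onnx(output_path="coqui_vits.onnx")

    # Reload through onnxruntime (cuda=True would select CUDAExecutionProvider)
    # and synthesize with an explicit speaker id.
    model.load_onnx("coqui_vits.onnx", cuda=False)
    x = np.asarray([model.tokenizer.text_to_ids("Hello world.")], dtype=np.int64)
    x_lengths = np.asarray([x.shape[1]], dtype=np.int64)
    wav = model.inference_onnx(x, x_lengths, speaker_id=0)

Omitting speaker_id (or passing None) keeps the old single-speaker behavior,
since the session then receives sid=None as before.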