From febcaf710a4fd5bdb82040abc908df765a2e4555 Mon Sep 17 00:00:00 2001 From: Julian Weber Date: Mon, 14 Aug 2023 21:02:48 +0200 Subject: [PATCH 1/6] Add customizable data home path (#2871) * Add customizable data home path * Add TTS_HOME as an option --- TTS/utils/generic_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index b8269883..f8a11ac5 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -126,7 +126,13 @@ def get_import_path(obj: object) -> str: def get_user_data_dir(appname): - if sys.platform == "win32": + TTS_HOME = os.environ.get("TTS_HOME") + XDG_DATA_HOME = os.environ.get("XDG_DATA_HOME") + if TTS_HOME is not None: + ans = Path(TTS_HOME).expanduser().resolve(strict=False) + elif XDG_DATA_HOME is not None: + ans = Path(XDG_DATA_HOME).expanduser().resolve(strict=False) + elif sys.platform == "win32": import winreg # pylint: disable=import-outside-toplevel key = winreg.OpenKey( From 409db505d24debc744e9bee277753b85bd2b53bf Mon Sep 17 00:00:00 2001 From: Jake Tae Date: Mon, 14 Aug 2023 15:04:44 -0400 Subject: [PATCH 2/6] Add device support in TTS and Synthesizer (#2855) * fix: resolve merge conflicts * fix: retain backwards compatability in functions * feature: utilize device for voice transfer * feature: use device for vocoder * chore: cleanup vocoder cpu logic * fix: add necessary vocoder output device check * fix: add necessary vocoder output device check * fix: indentation * fix: check if waveform is pt tensor before cpu conversion --------- Co-authored-by: Jake Tae --- TTS/api.py | 8 ++++- TTS/tts/utils/synthesis.py | 66 ++++++++++++++++++++++---------------- TTS/utils/synthesizer.py | 19 +++++++---- 3 files changed, 58 insertions(+), 35 deletions(-) diff --git a/TTS/api.py b/TTS/api.py index 5bb91362..2ee108ba 100644 --- a/TTS/api.py +++ b/TTS/api.py @@ -1,8 +1,10 @@ import tempfile +import warnings from pathlib import Path from typing import Union import numpy as np +from torch import nn from TTS.cs_api import CS_API from TTS.utils.audio.numpy_transforms import save_wav @@ -10,7 +12,7 @@ from TTS.utils.manage import ModelManager from TTS.utils.synthesizer import Synthesizer -class TTS: +class TTS(nn.Module): """TODO: Add voice conversion and Capacitron support.""" def __init__( @@ -62,6 +64,7 @@ class TTS: Defaults to "XTTS". gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False. """ + super().__init__() self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False) self.synthesizer = None @@ -70,6 +73,9 @@ class TTS: self.cs_api_model = cs_api_model self.model_name = None + if gpu: + warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.") + if model_name is not None: if "tts_models" in model_name or "coqui_studio" in model_name: self.load_tts_model_by_name(model_name, gpu) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 039816db..d3c29f93 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -5,19 +5,21 @@ import torch from torch import nn -def numpy_to_torch(np_array, dtype, cuda=False): +def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"): + if cuda: + device = "cuda" if np_array is None: return None - tensor = torch.as_tensor(np_array, dtype=dtype) - if cuda: - return tensor.cuda() + tensor = torch.as_tensor(np_array, dtype=dtype, device=device) return tensor -def compute_style_mel(style_wav, ap, cuda=False): - style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0) +def compute_style_mel(style_wav, ap, cuda=False, device="cpu"): if cuda: - return style_mel.cuda() + device = "cuda" + style_mel = torch.FloatTensor( + ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate)), device=device, + ).unsqueeze(0) return style_mel @@ -73,22 +75,22 @@ def inv_spectrogram(postnet_output, ap, CONFIG): return wav -def id_to_torch(aux_id, cuda=False): +def id_to_torch(aux_id, cuda=False, device="cpu"): + if cuda: + device = "cuda" if aux_id is not None: aux_id = np.asarray(aux_id) - aux_id = torch.from_numpy(aux_id) - if cuda: - return aux_id.cuda() + aux_id = torch.from_numpy(aux_id).to(device) return aux_id -def embedding_to_torch(d_vector, cuda=False): +def embedding_to_torch(d_vector, cuda=False, device="cpu"): + if cuda: + device = "cuda" if d_vector is not None: d_vector = np.asarray(d_vector) d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor) - d_vector = d_vector.squeeze().unsqueeze(0) - if cuda: - return d_vector.cuda() + d_vector = d_vector.squeeze().unsqueeze(0).to(device) return d_vector @@ -162,6 +164,11 @@ def synthesis( language_id (int): Language ID passed to the language embedding layer in multi-langual model. Defaults to None. """ + # device + device = next(model.parameters()).device + if use_cuda: + device = "cuda" + # GST or Capacitron processing # TODO: need to handle the case of setting both gst and capacitron to true somewhere style_mel = None @@ -169,10 +176,10 @@ def synthesis( if isinstance(style_wav, dict): style_mel = style_wav else: - style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda) + style_mel = compute_style_mel(style_wav, model.ap, device=device) if CONFIG.has("capacitron_vae") and CONFIG.use_capacitron_vae and style_wav is not None: - style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda) + style_mel = compute_style_mel(style_wav, model.ap, device=device) style_mel = style_mel.transpose(1, 2) # [1, time, depth] language_name = None @@ -188,26 +195,26 @@ def synthesis( ) # pass tensors to backend if speaker_id is not None: - speaker_id = id_to_torch(speaker_id, cuda=use_cuda) + speaker_id = id_to_torch(speaker_id, device=device) if d_vector is not None: - d_vector = embedding_to_torch(d_vector, cuda=use_cuda) + d_vector = embedding_to_torch(d_vector, device=device) if language_id is not None: - language_id = id_to_torch(language_id, cuda=use_cuda) + language_id = id_to_torch(language_id, device=device) if not isinstance(style_mel, dict): # GST or Capacitron style mel - style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) + style_mel = numpy_to_torch(style_mel, torch.float, device=device) if style_text is not None: style_text = np.asarray( model.tokenizer.text_to_ids(style_text, language=language_id), dtype=np.int32, ) - style_text = numpy_to_torch(style_text, torch.long, cuda=use_cuda) + style_text = numpy_to_torch(style_text, torch.long, device=device) style_text = style_text.unsqueeze(0) - text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda) + text_inputs = numpy_to_torch(text_inputs, torch.long, device=device) text_inputs = text_inputs.unsqueeze(0) # synthesize voice outputs = run_model_torch( @@ -290,22 +297,27 @@ def transfer_voice( do_trim_silence (bool): trim silence after synthesis. Defaults to False. """ + # device + device = next(model.parameters()).device + if use_cuda: + device = "cuda" + # pass tensors to backend if speaker_id is not None: - speaker_id = id_to_torch(speaker_id, cuda=use_cuda) + speaker_id = id_to_torch(speaker_id, device=device) if d_vector is not None: - d_vector = embedding_to_torch(d_vector, cuda=use_cuda) + d_vector = embedding_to_torch(d_vector, device=device) if reference_d_vector is not None: - reference_d_vector = embedding_to_torch(reference_d_vector, cuda=use_cuda) + reference_d_vector = embedding_to_torch(reference_d_vector, device=device) # load reference_wav audio reference_wav = embedding_to_torch( model.ap.load_wav( reference_wav, sr=model.args.encoder_sample_rate if model.args.encoder_sample_rate else model.ap.sample_rate ), - cuda=use_cuda, + device=device, ) if hasattr(model, "module"): diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index bc0e231d..fbae3216 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -5,6 +5,7 @@ from typing import List import numpy as np import pysbd import torch +from torch import nn from TTS.config import load_config from TTS.tts.configs.vits_config import VitsConfig @@ -21,7 +22,7 @@ from TTS.vocoder.models import setup_model as setup_vocoder_model from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input -class Synthesizer(object): +class Synthesizer(nn.Module): def __init__( self, tts_checkpoint: str = "", @@ -60,6 +61,7 @@ class Synthesizer(object): vc_config (str, optional): path to the voice conversion config file. Defaults to `""`, use_cuda (bool, optional): enable/disable cuda. Defaults to False. """ + super().__init__() self.tts_checkpoint = tts_checkpoint self.tts_config_path = tts_config_path self.tts_speakers_file = tts_speakers_file @@ -356,7 +358,12 @@ class Synthesizer(object): if speaker_wav is not None and self.tts_model.speaker_manager is not None: speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav) + vocoder_device = "cpu" use_gl = self.vocoder_model is None + if not use_gl: + vocoder_device = next(self.vocoder_model.parameters()).device + if self.use_cuda: + vocoder_device = "cuda" if not reference_wav: # not voice conversion for sen in sens: @@ -388,7 +395,6 @@ class Synthesizer(object): mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy() # denormalize tts output based on tts audio config mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T - device_type = "cuda" if self.use_cuda else "cpu" # renormalize spectrogram based on vocoder config vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) # compute scale factor for possible sample rate mismatch @@ -403,8 +409,8 @@ class Synthesizer(object): vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable # run vocoder model # [1, T, C] - waveform = self.vocoder_model.inference(vocoder_input.to(device_type)) - if self.use_cuda and not use_gl: + waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device)) + if torch.is_tensor(waveform) and waveform.device != torch.device("cpu") and not use_gl: waveform = waveform.cpu() if not use_gl: waveform = waveform.numpy() @@ -453,7 +459,6 @@ class Synthesizer(object): mel_postnet_spec = outputs[0].detach().cpu().numpy() # denormalize tts output based on tts audio config mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T - device_type = "cuda" if self.use_cuda else "cpu" # renormalize spectrogram based on vocoder config vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) # compute scale factor for possible sample rate mismatch @@ -468,8 +473,8 @@ class Synthesizer(object): vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable # run vocoder model # [1, T, C] - waveform = self.vocoder_model.inference(vocoder_input.to(device_type)) - if self.use_cuda: + waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device)) + if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"): waveform = waveform.cpu() if not use_gl: waveform = waveform.numpy() From a96562a7506b3106d5b7a32500e5e0e9c24daf3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sat, 26 Aug 2023 10:36:40 +0200 Subject: [PATCH 3/6] Update .models.json --- TTS/.models.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/TTS/.models.json b/TTS/.models.json index c97c6a38..dfd008b4 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -15,7 +15,7 @@ "hf_url": [ "https://coqui.gateway.scarf.sh/hf/bark/coarse_2.pt", "https://coqui.gateway.scarf.sh/hf/bark/fine_2.pt", - "https://coqui.gateway.scarf.sh/hf/bark/text_2.pt", + "https://app.coqui.ai/tts_model/bark_text_2", "https://coqui.gateway.scarf.sh/hf/bark/config.json" ], "default_vocoder": null, @@ -238,7 +238,7 @@ "tortoise-v2": { "description": "Tortoise tts model https://github.com/neonbjb/tortoise-tts", "github_rls_url": [ - "https://coqui.gateway.scarf.sh/v0.14.1_models/autoregressive.pth", + "https://app.coqui.ai/tts_model/tortoise", "https://coqui.gateway.scarf.sh/v0.14.1_models/clvp2.pth", "https://coqui.gateway.scarf.sh/v0.14.1_models/cvvp.pth", "https://coqui.gateway.scarf.sh/v0.14.1_models/diffusion_decoder.pth", @@ -879,4 +879,4 @@ } } } -} \ No newline at end of file +} From 04a36a727bc232e5efadbfcd5021d17ec42be68b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sat, 26 Aug 2023 10:39:48 +0200 Subject: [PATCH 4/6] Bump up to v0.16.4 --- TTS/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TTS/VERSION b/TTS/VERSION index 7eb3095a..5f2491c5 100644 --- a/TTS/VERSION +++ b/TTS/VERSION @@ -1 +1 @@ -0.16.3 +0.16.4 From a7a96d08ddd2e1b0b2b8fd21a10d00075305bf7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sat, 26 Aug 2023 11:59:00 +0200 Subject: [PATCH 5/6] Fix loading Bark (#2893) * Fixup hubert path * Make style --- TTS/.models.json | 8 +++++--- TTS/tts/models/bark.py | 6 ++++++ TTS/tts/utils/synthesis.py | 3 ++- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/TTS/.models.json b/TTS/.models.json index dfd008b4..69ac7514 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -15,8 +15,10 @@ "hf_url": [ "https://coqui.gateway.scarf.sh/hf/bark/coarse_2.pt", "https://coqui.gateway.scarf.sh/hf/bark/fine_2.pt", - "https://app.coqui.ai/tts_model/bark_text_2", - "https://coqui.gateway.scarf.sh/hf/bark/config.json" + "https://app.coqui.ai/tts_model/text_2.pt", + "https://coqui.gateway.scarf.sh/hf/bark/config.json", + "https://coqui.gateway.scarf.sh/hf/bark/hubert.pt", + "https://coqui.gateway.scarf.sh/hf/bark/tokenizer.pth" ], "default_vocoder": null, "commit": "e9a1953e", @@ -238,7 +240,7 @@ "tortoise-v2": { "description": "Tortoise tts model https://github.com/neonbjb/tortoise-tts", "github_rls_url": [ - "https://app.coqui.ai/tts_model/tortoise", + "https://app.coqui.ai/tts_model/autoregressive.pth", "https://coqui.gateway.scarf.sh/v0.14.1_models/clvp2.pth", "https://coqui.gateway.scarf.sh/v0.14.1_models/cvvp.pth", "https://coqui.gateway.scarf.sh/v0.14.1_models/diffusion_decoder.pth", diff --git a/TTS/tts/models/bark.py b/TTS/tts/models/bark.py index f198c3d5..e5edffd4 100644 --- a/TTS/tts/models/bark.py +++ b/TTS/tts/models/bark.py @@ -246,6 +246,8 @@ class Bark(BaseTTS): text_model_path=None, coarse_model_path=None, fine_model_path=None, + hubert_model_path=None, + hubert_tokenizer_path=None, eval=False, strict=True, **kwargs, @@ -267,10 +269,14 @@ class Bark(BaseTTS): text_model_path = text_model_path or os.path.join(checkpoint_dir, "text_2.pt") coarse_model_path = coarse_model_path or os.path.join(checkpoint_dir, "coarse_2.pt") fine_model_path = fine_model_path or os.path.join(checkpoint_dir, "fine_2.pt") + hubert_model_path = hubert_model_path or os.path.join(checkpoint_dir, "hubert.pt") + hubert_tokenizer_path = hubert_tokenizer_path or os.path.join(checkpoint_dir, "tokenizer.pth") self.config.LOCAL_MODEL_PATHS["text"] = text_model_path self.config.LOCAL_MODEL_PATHS["coarse"] = coarse_model_path self.config.LOCAL_MODEL_PATHS["fine"] = fine_model_path + self.config.LOCAL_MODEL_PATHS["hubert"] = hubert_model_path + self.config.LOCAL_MODEL_PATHS["hubert_tokenizer"] = hubert_tokenizer_path self.load_bark_models() diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index d3c29f93..797151c2 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -18,7 +18,8 @@ def compute_style_mel(style_wav, ap, cuda=False, device="cpu"): if cuda: device = "cuda" style_mel = torch.FloatTensor( - ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate)), device=device, + ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate)), + device=device, ).unsqueeze(0) return style_mel From c0b5e6174962e3ca94bb5d0f3e7826987fe6bdd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Sat, 26 Aug 2023 12:00:25 +0200 Subject: [PATCH 6/6] Bump up to v0.16.5 --- TTS/VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TTS/VERSION b/TTS/VERSION index 5f2491c5..19270385 100644 --- a/TTS/VERSION +++ b/TTS/VERSION @@ -1 +1 @@ -0.16.4 +0.16.5