mirror of https://github.com/coqui-ai/TTS.git, commit 530a8939fe
@@ -15,8 +15,10 @@
             "hf_url": [
                 "https://coqui.gateway.scarf.sh/hf/bark/coarse_2.pt",
                 "https://coqui.gateway.scarf.sh/hf/bark/fine_2.pt",
-                "https://coqui.gateway.scarf.sh/hf/bark/text_2.pt",
-                "https://coqui.gateway.scarf.sh/hf/bark/config.json"
+                "https://app.coqui.ai/tts_model/text_2.pt",
+                "https://coqui.gateway.scarf.sh/hf/bark/config.json",
+                "https://coqui.gateway.scarf.sh/hf/bark/hubert.pt",
+                "https://coqui.gateway.scarf.sh/hf/bark/tokenizer.pth"
             ],
             "default_vocoder": null,
             "commit": "e9a1953e",
@@ -238,7 +240,7 @@
         "tortoise-v2": {
             "description": "Tortoise tts model https://github.com/neonbjb/tortoise-tts",
             "github_rls_url": [
-                "https://coqui.gateway.scarf.sh/v0.14.1_models/autoregressive.pth",
+                "https://app.coqui.ai/tts_model/autoregressive.pth",
                 "https://coqui.gateway.scarf.sh/v0.14.1_models/clvp2.pth",
                 "https://coqui.gateway.scarf.sh/v0.14.1_models/cvvp.pth",
                 "https://coqui.gateway.scarf.sh/v0.14.1_models/diffusion_decoder.pth",
@@ -879,4 +881,4 @@
             }
         }
     }
 }

@@ -1 +1 @@
-0.16.3
+0.16.5

@@ -1,8 +1,10 @@
 import tempfile
+import warnings
 from pathlib import Path
 from typing import Union

 import numpy as np
+from torch import nn

 from TTS.cs_api import CS_API
 from TTS.utils.audio.numpy_transforms import save_wav
@@ -10,7 +12,7 @@ from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer


-class TTS:
+class TTS(nn.Module):
     """TODO: Add voice conversion and Capacitron support."""

     def __init__(
@@ -62,6 +64,7 @@ class TTS:
                 Defaults to "XTTS".
             gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
         """
+        super().__init__()
         self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False)

         self.synthesizer = None
@@ -70,6 +73,9 @@ class TTS:
         self.cs_api_model = cs_api_model
         self.model_name = None

+        if gpu:
+            warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")
+
         if model_name is not None:
             if "tts_models" in model_name or "coqui_studio" in model_name:
                 self.load_tts_model_by_name(model_name, gpu)

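Note: with TTS now subclassing nn.Module and the `gpu` flag deprecated, device placement moves to the standard torch call. A minimal sketch of the usage implied by the warning above (the model name is illustrative, not part of this commit):

    import torch

    from TTS.api import TTS

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # The old call TTS(model_name, gpu=True) now emits a deprecation warning;
    # since TTS subclasses nn.Module, it moves like any other torch module.
    tts = TTS("tts_models/en/ljspeech/tacotron2-DDC").to(device)
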
@@ -246,6 +246,8 @@ class Bark(BaseTTS):
         text_model_path=None,
         coarse_model_path=None,
         fine_model_path=None,
+        hubert_model_path=None,
+        hubert_tokenizer_path=None,
         eval=False,
         strict=True,
         **kwargs,
@@ -267,10 +269,14 @@ class Bark(BaseTTS):
         text_model_path = text_model_path or os.path.join(checkpoint_dir, "text_2.pt")
         coarse_model_path = coarse_model_path or os.path.join(checkpoint_dir, "coarse_2.pt")
         fine_model_path = fine_model_path or os.path.join(checkpoint_dir, "fine_2.pt")
+        hubert_model_path = hubert_model_path or os.path.join(checkpoint_dir, "hubert.pt")
+        hubert_tokenizer_path = hubert_tokenizer_path or os.path.join(checkpoint_dir, "tokenizer.pth")

         self.config.LOCAL_MODEL_PATHS["text"] = text_model_path
         self.config.LOCAL_MODEL_PATHS["coarse"] = coarse_model_path
         self.config.LOCAL_MODEL_PATHS["fine"] = fine_model_path
+        self.config.LOCAL_MODEL_PATHS["hubert"] = hubert_model_path
+        self.config.LOCAL_MODEL_PATHS["hubert_tokenizer"] = hubert_tokenizer_path

         self.load_bark_models()

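The two new keyword arguments follow the same fallback pattern as the existing checkpoint paths: an explicit path wins, otherwise the file is expected under `checkpoint_dir`. The pattern in isolation (paths here are illustrative):

    import os

    checkpoint_dir = "/tmp/bark"  # illustrative
    hubert_model_path = None      # caller passed nothing explicit
    hubert_model_path = hubert_model_path or os.path.join(checkpoint_dir, "hubert.pt")
    assert hubert_model_path == "/tmp/bark/hubert.pt"
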
@@ -5,19 +5,22 @@ import torch
 from torch import nn


-def numpy_to_torch(np_array, dtype, cuda=False):
+def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"):
+    if cuda:
+        device = "cuda"
     if np_array is None:
         return None
-    tensor = torch.as_tensor(np_array, dtype=dtype)
-    if cuda:
-        return tensor.cuda()
+    tensor = torch.as_tensor(np_array, dtype=dtype, device=device)
     return tensor


-def compute_style_mel(style_wav, ap, cuda=False):
-    style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
+def compute_style_mel(style_wav, ap, cuda=False, device="cpu"):
     if cuda:
-        return style_mel.cuda()
+        device = "cuda"
+    style_mel = torch.FloatTensor(
+        ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate)),
+        device=device,
+    ).unsqueeze(0)
     return style_mel

@@ -73,22 +76,22 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
     return wav


-def id_to_torch(aux_id, cuda=False):
+def id_to_torch(aux_id, cuda=False, device="cpu"):
+    if cuda:
+        device = "cuda"
     if aux_id is not None:
         aux_id = np.asarray(aux_id)
-        aux_id = torch.from_numpy(aux_id)
-    if cuda:
-        return aux_id.cuda()
+        aux_id = torch.from_numpy(aux_id).to(device)
     return aux_id


-def embedding_to_torch(d_vector, cuda=False):
+def embedding_to_torch(d_vector, cuda=False, device="cpu"):
+    if cuda:
+        device = "cuda"
     if d_vector is not None:
         d_vector = np.asarray(d_vector)
         d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor)
-        d_vector = d_vector.squeeze().unsqueeze(0)
-    if cuda:
-        return d_vector.cuda()
+        d_vector = d_vector.squeeze().unsqueeze(0).to(device)
     return d_vector

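All four helpers now follow one convention: `device` selects placement directly, while `cuda=True` is kept as a backward-compatible override that forces "cuda". A quick self-contained check against `id_to_torch` as patched above:

    import numpy as np
    import torch

    def id_to_torch(aux_id, cuda=False, device="cpu"):
        if cuda:
            device = "cuda"
        if aux_id is not None:
            aux_id = np.asarray(aux_id)
            aux_id = torch.from_numpy(aux_id).to(device)
        return aux_id

    speaker_id = id_to_torch(3, device="cpu")
    assert speaker_id.device.type == "cpu" and speaker_id.item() == 3
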
@@ -162,6 +165,11 @@ def synthesis(
         language_id (int):
             Language ID passed to the language embedding layer in multi-langual model. Defaults to None.
     """
+    # device
+    device = next(model.parameters()).device
+    if use_cuda:
+        device = "cuda"
+
     # GST or Capacitron processing
     # TODO: need to handle the case of setting both gst and capacitron to true somewhere
     style_mel = None
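The block added above makes `synthesis` follow wherever the model already lives, with `use_cuda` kept as an explicit override. The idiom on a toy module:

    import torch
    from torch import nn

    model = nn.Linear(4, 4)  # stand-in for the TTS model
    device = next(model.parameters()).device
    print(device)  # cpu, unless model.to("cuda") was called first
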
@@ -169,10 +177,10 @@ def synthesis(
         if isinstance(style_wav, dict):
             style_mel = style_wav
         else:
-            style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
+            style_mel = compute_style_mel(style_wav, model.ap, device=device)

     if CONFIG.has("capacitron_vae") and CONFIG.use_capacitron_vae and style_wav is not None:
-        style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
+        style_mel = compute_style_mel(style_wav, model.ap, device=device)
         style_mel = style_mel.transpose(1, 2)  # [1, time, depth]

     language_name = None
@@ -188,26 +196,26 @@ def synthesis(
     )
     # pass tensors to backend
     if speaker_id is not None:
-        speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
+        speaker_id = id_to_torch(speaker_id, device=device)

     if d_vector is not None:
-        d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
+        d_vector = embedding_to_torch(d_vector, device=device)

     if language_id is not None:
-        language_id = id_to_torch(language_id, cuda=use_cuda)
+        language_id = id_to_torch(language_id, device=device)

     if not isinstance(style_mel, dict):
         # GST or Capacitron style mel
-        style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
+        style_mel = numpy_to_torch(style_mel, torch.float, device=device)
         if style_text is not None:
             style_text = np.asarray(
                 model.tokenizer.text_to_ids(style_text, language=language_id),
                 dtype=np.int32,
             )
-            style_text = numpy_to_torch(style_text, torch.long, cuda=use_cuda)
+            style_text = numpy_to_torch(style_text, torch.long, device=device)
             style_text = style_text.unsqueeze(0)

-    text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
+    text_inputs = numpy_to_torch(text_inputs, torch.long, device=device)
     text_inputs = text_inputs.unsqueeze(0)
     # synthesize voice
     outputs = run_model_torch(
@@ -290,22 +298,27 @@ def transfer_voice(
         do_trim_silence (bool):
             trim silence after synthesis. Defaults to False.
     """
+    # device
+    device = next(model.parameters()).device
+    if use_cuda:
+        device = "cuda"
+
     # pass tensors to backend
     if speaker_id is not None:
-        speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
+        speaker_id = id_to_torch(speaker_id, device=device)

     if d_vector is not None:
-        d_vector = embedding_to_torch(d_vector, cuda=use_cuda)
+        d_vector = embedding_to_torch(d_vector, device=device)

     if reference_d_vector is not None:
-        reference_d_vector = embedding_to_torch(reference_d_vector, cuda=use_cuda)
+        reference_d_vector = embedding_to_torch(reference_d_vector, device=device)

     # load reference_wav audio
     reference_wav = embedding_to_torch(
         model.ap.load_wav(
             reference_wav, sr=model.args.encoder_sample_rate if model.args.encoder_sample_rate else model.ap.sample_rate
         ),
-        cuda=use_cuda,
+        device=device,
     )

     if hasattr(model, "module"):

@@ -126,7 +126,13 @@ def get_import_path(obj: object) -> str:


 def get_user_data_dir(appname):
-    if sys.platform == "win32":
+    TTS_HOME = os.environ.get("TTS_HOME")
+    XDG_DATA_HOME = os.environ.get("XDG_DATA_HOME")
+    if TTS_HOME is not None:
+        ans = Path(TTS_HOME).expanduser().resolve(strict=False)
+    elif XDG_DATA_HOME is not None:
+        ans = Path(XDG_DATA_HOME).expanduser().resolve(strict=False)
+    elif sys.platform == "win32":
         import winreg  # pylint: disable=import-outside-toplevel

         key = winreg.OpenKey(

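`get_user_data_dir` now honours two environment variables before falling back to the platform default. A standalone sketch of just the lookup order (the win32 registry branch of the original function is elided here, and the POSIX fallback is an assumption):

    import os
    from pathlib import Path

    def resolve_data_dir():
        tts_home = os.environ.get("TTS_HOME")
        xdg_data_home = os.environ.get("XDG_DATA_HOME")
        if tts_home is not None:
            return Path(tts_home).expanduser().resolve(strict=False)
        if xdg_data_home is not None:
            return Path(xdg_data_home).expanduser().resolve(strict=False)
        return Path.home() / ".local" / "share"  # assumed POSIX fallback; see the diff for win32

    os.environ["TTS_HOME"] = "/opt/tts-data"  # illustrative
    print(resolve_data_dir())  # /opt/tts-data
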
@@ -5,6 +5,7 @@ from typing import List
 import numpy as np
 import pysbd
 import torch
+from torch import nn

 from TTS.config import load_config
 from TTS.tts.configs.vits_config import VitsConfig
@@ -21,7 +22,7 @@ from TTS.vocoder.models import setup_model as setup_vocoder_model
 from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input


-class Synthesizer(object):
+class Synthesizer(nn.Module):
     def __init__(
         self,
         tts_checkpoint: str = "",
@@ -60,6 +61,7 @@ class Synthesizer(object):
             vc_config (str, optional): path to the voice conversion config file. Defaults to `""`,
             use_cuda (bool, optional): enable/disable cuda. Defaults to False.
         """
+        super().__init__()
         self.tts_checkpoint = tts_checkpoint
         self.tts_config_path = tts_config_path
         self.tts_speakers_file = tts_speakers_file
@@ -356,7 +358,12 @@ class Synthesizer(object):
         if speaker_wav is not None and self.tts_model.speaker_manager is not None:
             speaker_embedding = self.tts_model.speaker_manager.compute_embedding_from_clip(speaker_wav)

+        vocoder_device = "cpu"
         use_gl = self.vocoder_model is None
+        if not use_gl:
+            vocoder_device = next(self.vocoder_model.parameters()).device
+        if self.use_cuda:
+            vocoder_device = "cuda"

         if not reference_wav:  # not voice conversion
             for sen in sens:
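The vocoder can live on a different device than the TTS model, so its device is resolved separately: "cpu" by default (also the Griffin-Lim case), the vocoder's own parameters when one is loaded, and "cuda" when `use_cuda` forces it. The same resolution on a stand-in module:

    import torch
    from torch import nn

    vocoder_model = nn.Linear(80, 1)  # stand-in for the loaded vocoder
    use_cuda = False

    vocoder_device = "cpu"
    use_gl = vocoder_model is None  # Griffin-Lim fallback when no vocoder is loaded
    if not use_gl:
        vocoder_device = next(vocoder_model.parameters()).device
    if use_cuda:
        vocoder_device = "cuda"
    print(vocoder_device)  # cpu
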
@@ -388,7 +395,6 @@ class Synthesizer(object):
                     mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
                     # denormalize tts output based on tts audio config
                     mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
-                    device_type = "cuda" if self.use_cuda else "cpu"
                     # renormalize spectrogram based on vocoder config
                     vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
                     # compute scale factor for possible sample rate mismatch
@@ -403,8 +409,8 @@ class Synthesizer(object):
                     vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
                 # run vocoder model
                 # [1, T, C]
-                waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
-            if self.use_cuda and not use_gl:
+                waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
+            if torch.is_tensor(waveform) and waveform.device != torch.device("cpu") and not use_gl:
                 waveform = waveform.cpu()
             if not use_gl:
                 waveform = waveform.numpy()
@@ -453,7 +459,6 @@ class Synthesizer(object):
             mel_postnet_spec = outputs[0].detach().cpu().numpy()
             # denormalize tts output based on tts audio config
             mel_postnet_spec = self.tts_model.ap.denormalize(mel_postnet_spec.T).T
-            device_type = "cuda" if self.use_cuda else "cpu"
             # renormalize spectrogram based on vocoder config
             vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
             # compute scale factor for possible sample rate mismatch
@@ -468,8 +473,8 @@ class Synthesizer(object):
             vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
         # run vocoder model
         # [1, T, C]
-        waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
-    if self.use_cuda:
+        waveform = self.vocoder_model.inference(vocoder_input.to(vocoder_device))
+    if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"):
         waveform = waveform.cpu()
     if not use_gl:
         waveform = waveform.numpy()
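The old guard keyed off `self.use_cuda`; the new one inspects the waveform itself, which also copes with Griffin-Lim output (a numpy array rather than a tensor). In isolation:

    import torch

    waveform = torch.zeros(1, 16000)  # stand-in for vocoder output
    if torch.is_tensor(waveform) and waveform.device != torch.device("cpu"):
        waveform = waveform.cpu()  # only tensors on a non-CPU device are moved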