mirror of https://github.com/coqui-ai/TTS.git
Cache fsspec downloads (#2132)
* Cache fsspec downloaded files
* Use diff paths for test
* Make fsspec caching optional
* Decom GPU docker tests
* Make progress bar optional for better CI log
* Check path local
This commit is contained in:
parent c5412532ac
commit 8cb1433e6e
@@ -15,7 +15,7 @@ jobs:
       matrix:
         arch: ["amd64"]
         base:
-          - "nvcr.io/nvidia/pytorch:22.03-py3" # GPU enabled
+          # - "nvcr.io/nvidia/pytorch:22.03-py3" # GPU enabled
           - "ubuntu:20.04" # CPU only
     steps:
       - uses: actions/checkout@v2
@@ -238,6 +238,13 @@ If you don't specify any models, then it uses LJSpeech based English model.
         help="speaker ID of the reference_wav speaker (If not provided the embedding will be computed using the Speaker Encoder).",
         default=None,
     )
+    parser.add_argument(
+        "--progress_bar",
+        type=str2bool,
+        help="If true shows a progress bar for the model download. Defaults to True",
+        default=True,
+    )
+
     args = parser.parse_args()

     # print the description if either text or list_models is not set
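The new flag lets a caller silence the model-download progress bar, which keeps CI logs short. A minimal sketch of an invocation of the `tts` CLI with the flag; the model name and output path are placeholders, not part of this diff:

import subprocess

# model name and output path are illustrative only
subprocess.run(
    [
        "tts",
        "--model_name", "tts_models/en/ljspeech/glow-tts",
        "--text", "This is an example.",
        "--out_path", "out.wav",
        "--progress_bar", "False",  # parsed by str2bool, so the string "False" disables the bar
    ],
    check=True,
)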
@@ -255,7 +262,7 @@ If you don't specify any models, then it uses LJSpeech based English model.

     # load model manager
     path = Path(__file__).parent / "../.models.json"
-    manager = ModelManager(path)
+    manager = ModelManager(path, progress_bar=args.progress_bar)

     model_path = None
     config_path = None
@@ -107,9 +107,15 @@ class BaseEncoder(nn.Module):
         return criterion

     def load_checkpoint(
-        self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None
+        self,
+        config: Coqpit,
+        checkpoint_path: str,
+        eval: bool = False,
+        use_cuda: bool = False,
+        criterion=None,
+        cache=False,
     ):
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         try:
             self.load_state_dict(state["model"])
             print(" > Model fully restored. ")
@@ -44,13 +44,16 @@ class BaseTrainerModel(TrainerModel):
         return outputs_dict

     @abstractmethod
-    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None:
+    def load_checkpoint(
+        self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False
+    ) -> None:
         """Load a model checkpoint gile and get ready for training or inference.

         Args:
             config (Coqpit): Model configuration.
             checkpoint_path (str): Path to the model checkpoint file.
             eval (bool, optional): If true, init model for inference else for training. Defaults to False.
-            strcit (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
+            strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
+            cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
         """
         ...
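Because `cache` defaults to False, existing call sites keep their behavior and callers opt in explicitly. A hedged sketch of the calling convention the new signature implies; the config path, model factory, and checkpoint URL are placeholders, and `setup_model` is assumed to be the repo's usual TTS model factory:

from TTS.config import load_config
from TTS.tts.models import setup_model

config = load_config("config.json")  # placeholder config path
model = setup_model(config)
# a remote checkpoint is fetched once and reused from get_user_data_dir()/tts_cache
model.load_checkpoint(config, "s3://my-bucket/best_model.pth", eval=True, cache=True)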
@@ -398,9 +398,9 @@ class AlignTTS(BaseTTS):
             logger.eval_audios(steps, audios, self.ap.sample_rate)

     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -92,16 +92,17 @@ class BaseTacotron(BaseTTS):
         pass

     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
         """Load model checkpoint and set up internals.

         Args:
             config (Coqpi): model configuration.
             checkpoint_path (str): path to checkpoint file.
-            eval (bool): whether to load model for evaluation.
+            eval (bool, optional): whether to load model for evaluation.
+            cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
         """
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         # TODO: set r in run-time by taking it from the new config
         if "r" in state:
@@ -16,6 +16,7 @@ from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
+from TTS.utils.io import load_fsspec


 @dataclass
@@ -707,9 +708,9 @@ class ForwardTTS(BaseTTS):
             logger.eval_audios(steps, audios, self.ap.sample_rate)

     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -1686,14 +1686,10 @@ class Vits(BaseTTS):
         return [VitsDiscriminatorLoss(self.config), VitsGeneratorLoss(self.config)]

     def load_checkpoint(
-        self,
-        config,
-        checkpoint_path,
-        eval=False,
-        strict=True,
+        self, config, checkpoint_path, eval=False, strict=True, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
         """Load the model checkpoint and setup for training or inference"""
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         # compat band-aid for the pre-trained models to not use the encoder baked into the model
         # TODO: consider baking the speaker encoder into the model and call it from there.
         # as it is probably easier for model distribution.
@@ -9,6 +9,8 @@ import fsspec
 import torch
 from coqpit import Coqpit

+from TTS.utils.generic_utils import get_user_data_dir
+

 class RenamingUnpickler(pickle_tts.Unpickler):
     """Overload default pickler to solve module renaming problem"""
@@ -57,6 +59,7 @@ def copy_model_files(config: Coqpit, out_path, new_fields=None):
 def load_fsspec(
     path: str,
     map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
+    cache: bool = True,
     **kwargs,
 ) -> Any:
     """Like torch.load but can load from other locations (e.g. s3:// , gs://).
@@ -64,21 +67,33 @@ def load_fsspec(
     Args:
         path: Any path or url supported by fsspec.
         map_location: torch.device or str.
+        cache: If True, cache a remote file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to True.
         **kwargs: Keyword arguments forwarded to torch.load.

     Returns:
         Object stored in path.
     """
-    with fsspec.open(path, "rb") as f:
-        return torch.load(f, map_location=map_location, **kwargs)
+    is_local = os.path.isdir(path) or os.path.isfile(path)
+    if cache and not is_local:
+        with fsspec.open(
+            f"filecache::{path}",
+            filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
+            mode="rb",
+        ) as f:
+            return torch.load(f, map_location=map_location, **kwargs)
+    else:
+        with fsspec.open(path, "rb") as f:
+            return torch.load(f, map_location=map_location, **kwargs)


-def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
+def load_checkpoint(
+    model, checkpoint_path, use_cuda=False, eval=False, cache=False
+):  # pylint: disable=redefined-builtin
     try:
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
     except ModuleNotFoundError:
         pickle_tts.Unpickler = RenamingUnpickler
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache)
     model.load_state_dict(state["model"])
     if use_cuda:
         model.cuda()
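The `filecache::` prefix chains fsspec's caching filesystem in front of whatever protocol `path` uses, so a remote file is downloaded once into `cache_storage` and served from disk on later opens; local paths skip the cache entirely. A standalone sketch of the same mechanism, with a hypothetical URL and cache directory standing in for the real checkpoint and `get_user_data_dir()/tts_cache`:

import os
import fsspec

url = "https://example.com/checkpoints/model_file.pth"      # hypothetical remote file
cache_dir = os.path.expanduser("~/.cache/tts_cache_demo")   # hypothetical cache location

# first call downloads into cache_dir; subsequent calls read the cached copy
with fsspec.open(f"filecache::{url}", filecache={"cache_storage": cache_dir}, mode="rb") as f:
    payload = f.read()
print(len(payload), "bytes read (served from cache on later runs)")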
@@ -32,11 +32,14 @@ class ModelManager(object):
     home path.

     Args:
-        models_file (str): path to .model.json
+        models_file (str): path to .model.json file. Defaults to None.
+        output_prefix (str): prefix to `tts` to download models. Defaults to None
+        progress_bar (bool): print a progress bar when donwloading a file. Defaults to False.
     """

-    def __init__(self, models_file=None, output_prefix=None):
+    def __init__(self, models_file=None, output_prefix=None, progress_bar=False):
         super().__init__()
+        self.progress_bar = progress_bar
         if output_prefix is None:
             self.output_prefix = get_user_data_dir("tts")
         else:
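Note the asymmetry: `ModelManager` defaults `progress_bar` to False for library use, while the CLI passes True unless `--progress_bar False` is given. A short sketch of constructing the manager quietly, e.g. for CI; the model name is a placeholder and the three-value return of `download_model` is assumed from how the manager is used elsewhere in the repo:

from TTS.utils.manage import ModelManager

manager = ModelManager(progress_bar=False)  # no tqdm output during downloads
print(manager.list_models()[:5])            # resolves the bundled .models.json by default
model_path, config_path, model_item = manager.download_model("tts_models/en/ljspeech/glow-tts")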
@@ -236,7 +239,7 @@ class ModelManager(object):
         os.makedirs(output_path, exist_ok=True)
         print(f" > Downloading model to {output_path}")
         # download from github release
-        self._download_zip_file(model_item["github_rls_url"], output_path)
+        self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
         self.print_model_license(model_item=model_item)
         # find downloaded files
         output_model_path, output_config_path = self._find_files(output_path)
@@ -334,7 +337,7 @@ class ModelManager(object):
         config.save_json(config_path)

     @staticmethod
-    def _download_zip_file(file_url, output_folder):
+    def _download_zip_file(file_url, output_folder, progress_bar):
         """Download the github releases"""
         # download the file
         r = requests.get(file_url, stream=True)
@@ -342,11 +345,13 @@ class ModelManager(object):
         try:
             total_size_in_bytes = int(r.headers.get("content-length", 0))
             block_size = 1024  # 1 Kibibyte
-            progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
+            if progress_bar:
+                progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
             temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
             with open(temp_zip_name, "wb") as file:
                 for data in r.iter_content(block_size):
-                    progress_bar.update(len(data))
+                    if progress_bar:
+                        progress_bar.update(len(data))
                     file.write(data)
             with zipfile.ZipFile(temp_zip_name) as z:
                 z.extractall(output_folder)
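When `progress_bar` is False the tqdm bar is never created, so the streaming loop writes chunks silently. The same guard pattern in isolation, with a hypothetical URL and output file standing in for the GitHub release download:

import requests
from tqdm import tqdm

def download(url, out_file, progress_bar=False):
    # stream the file in 1 KiB chunks, optionally showing a tqdm bar
    r = requests.get(url, stream=True)
    total = int(r.headers.get("content-length", 0))
    bar = tqdm(total=total, unit="iB", unit_scale=True) if progress_bar else None
    with open(out_file, "wb") as f:
        for chunk in r.iter_content(1024):
            if bar:
                bar.update(len(chunk))
            f.write(chunk)

download("https://example.com/model.zip", "model.zip", progress_bar=False)  # hypothetical URL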
@@ -231,6 +231,7 @@ class GAN(BaseVocoder):
         config: Coqpit,
         checkpoint_path: str,
         eval: bool = False,  # pylint: disable=unused-argument, redefined-builtin
+        cache: bool = False,
     ) -> None:
         """Load a GAN checkpoint and initialize model parameters.

@@ -239,7 +240,7 @@ class GAN(BaseVocoder):
             checkpoint_path (str): Checkpoint file path.
             eval (bool, optional): If true, load the model for inference. If falseDefaults to False.
         """
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         # band-aid for older than v0.0.15 GAN models
         if "model_disc" in state:
             self.model_g.load_checkpoint(config, checkpoint_path, eval)
@@ -290,9 +290,9 @@ class HifiganGenerator(torch.nn.Module):
         remove_weight_norm(self.conv_post)

     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -85,9 +85,9 @@ class MelganGenerator(nn.Module):
             layer.remove_weight_norm()

     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -153,9 +153,9 @@ class ParallelWaveganGenerator(torch.nn.Module):
         return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)

     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -218,9 +218,9 @@ class Wavegrad(BaseVocoder):
         self.y_conv = weight_norm(self.y_conv)

     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -542,9 +542,9 @@ class Wavernn(BaseVocoder):
         return unfolded

     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -15,7 +15,7 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
     def test_GlowTTS():
         # set paths
         config_path = os.path.join(get_tests_input_path(), "test_glow_tts.json")
-        checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth")
+        checkpoint_path = os.path.join(get_tests_output_path(), "glowtts.pth")
         output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
         # load config
         c = load_config(config_path)
@@ -33,7 +33,7 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
     def test_Tacotron2():
         # set paths
         config_path = os.path.join(get_tests_input_path(), "test_tacotron2_config.json")
-        checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth")
+        checkpoint_path = os.path.join(get_tests_output_path(), "tacotron2.pth")
         output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
         # load config
         c = load_config(config_path)
@@ -51,7 +51,7 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
     def test_Tacotron():
         # set paths
         config_path = os.path.join(get_tests_input_path(), "test_tacotron_config.json")
-        checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth")
+        checkpoint_path = os.path.join(get_tests_output_path(), "tacotron.pth")
         output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
         # load config
         c = load_config(config_path)
@@ -15,7 +15,7 @@ def test_run_all_models():
     print(" > Run synthesizer with all the models.")
     download_dir = get_user_data_dir("tts")
     output_path = os.path.join(get_tests_output_path(), "output.wav")
-    manager = ModelManager(output_prefix=get_tests_output_path())
+    manager = ModelManager(output_prefix=get_tests_output_path(), progress_bar=False)
     model_names = manager.list_models()
     for model_name in model_names:
         print(f"\n > Run - {model_name}")
@@ -41,11 +41,14 @@ def test_run_all_models():
                 speaker_id = list(speaker_manager.name_to_id.keys())[0]
                 run_cli(
                     f"tts --model_name {model_name} "
-                    f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" '
+                    f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" --progress_bar False'
                 )
             else:
                 # single-speaker model
-                run_cli(f"tts --model_name {model_name} " f'--text "This is an example." --out_path "{output_path}"')
+                run_cli(
+                    f"tts --model_name {model_name} "
+                    f'--text "This is an example." --out_path "{output_path}" --progress_bar False'
+                )
             # remove downloaded models
             shutil.rmtree(download_dir)
         else:
@@ -67,5 +70,5 @@ def test_voice_conversion():
     output_path = os.path.join(get_tests_output_path(), "output.wav")
     run_cli(
         f"tts --model_name {model_name}"
-        f" --out_path {output_path} --speaker_wav {speaker_wav} --reference_wav {reference_wav} --language_idx {language_id} "
+        f" --out_path {output_path} --speaker_wav {speaker_wav} --reference_wav {reference_wav} --language_idx {language_id} --progress_bar False"
     )