From 8cb1433e6e4e14d1e144906c724dfca9a3ad34f4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Wed, 9 Nov 2022 22:12:48 +0100
Subject: [PATCH] Cache fsspec downloads (#2132)

* Cache fsspec downloaded files

* Use diff paths for test

* Make fsspec caching optional

* Decom GPU docker tests

* Make progress bar optional for better CI log

* Check path local
---
 .github/workflows/docker.yaml                 |  2 +-
 TTS/bin/synthesize.py                         |  9 ++++++-
 TTS/encoder/models/base_encoder.py            | 10 ++++++--
 TTS/model.py                                  |  7 ++++--
 TTS/tts/models/align_tts.py                   |  4 +--
 TTS/tts/models/base_tacotron.py               |  7 +++---
 TTS/tts/models/forward_tts.py                 |  5 ++--
 TTS/tts/models/vits.py                        |  8 ++----
 TTS/utils/io.py                               | 25 +++++++++++++++----
 TTS/utils/manage.py                           | 17 ++++++++-----
 TTS/vocoder/models/gan.py                     |  3 ++-
 TTS/vocoder/models/hifigan_generator.py       |  4 +--
 TTS/vocoder/models/melgan_generator.py        |  4 +--
 .../models/parallel_wavegan_generator.py      |  4 +--
 TTS/vocoder/models/wavegrad.py                |  4 +--
 TTS/vocoder/models/wavernn.py                 |  4 +--
 .../test_extract_tts_spectrograms.py          |  6 ++---
 tests/zoo_tests/test_models.py                | 11 +++++---
 18 files changed, 86 insertions(+), 48 deletions(-)

diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml
index 7d383f3f..67e9cc0c 100644
--- a/.github/workflows/docker.yaml
+++ b/.github/workflows/docker.yaml
@@ -15,7 +15,7 @@ jobs:
       matrix:
         arch: ["amd64"]
         base:
-          - "nvcr.io/nvidia/pytorch:22.03-py3" # GPU enabled
+          # - "nvcr.io/nvidia/pytorch:22.03-py3" # GPU enabled
           - "ubuntu:20.04" # CPU only
     steps:
       - uses: actions/checkout@v2
diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index 90ed9746..bbcb9c95 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -238,6 +238,13 @@ If you don't specify any models, then it uses LJSpeech based English model.
         help="speaker ID of the reference_wav speaker (If not provided the embedding will be computed using the Speaker Encoder).",
         default=None,
     )
+    parser.add_argument(
+        "--progress_bar",
+        type=str2bool,
+        help="If true shows a progress bar for the model download. Defaults to True",
+        default=True,
+    )
+
     args = parser.parse_args()
 
     # print the description if either text or list_models is not set
@@ -255,7 +262,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
 
     # load model manager
     path = Path(__file__).parent / "../.models.json"
-    manager = ModelManager(path)
+    manager = ModelManager(path, progress_bar=args.progress_bar)
 
     model_path = None
     config_path = None
diff --git a/TTS/encoder/models/base_encoder.py b/TTS/encoder/models/base_encoder.py
index f741a2de..957ea3c4 100644
--- a/TTS/encoder/models/base_encoder.py
+++ b/TTS/encoder/models/base_encoder.py
@@ -107,9 +107,15 @@ class BaseEncoder(nn.Module):
         return criterion
 
     def load_checkpoint(
-        self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None
+        self,
+        config: Coqpit,
+        checkpoint_path: str,
+        eval: bool = False,
+        use_cuda: bool = False,
+        criterion=None,
+        cache=False,
     ):
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         try:
             self.load_state_dict(state["model"])
             print(" > Model fully restored. ")
") diff --git a/TTS/model.py b/TTS/model.py index a53b916a..ae6be7b4 100644 --- a/TTS/model.py +++ b/TTS/model.py @@ -44,13 +44,16 @@ class BaseTrainerModel(TrainerModel): return outputs_dict @abstractmethod - def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None: + def load_checkpoint( + self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False + ) -> None: """Load a model checkpoint gile and get ready for training or inference. Args: config (Coqpit): Model configuration. checkpoint_path (str): Path to the model checkpoint file. eval (bool, optional): If true, init model for inference else for training. Defaults to False. - strcit (bool, optional): Match all checkpoint keys to model's keys. Defaults to True. + strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True. + cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False. """ ... diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index c1e2ffb3..4fdaa596 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -398,9 +398,9 @@ class AlignTTS(BaseTTS): logger.eval_audios(steps, audios, self.ap.sample_rate) def load_checkpoint( - self, config, checkpoint_path, eval=False + self, config, checkpoint_path, eval=False, cache=False ): # pylint: disable=unused-argument, redefined-builtin - state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) self.load_state_dict(state["model"]) if eval: self.eval() diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py index c0f4c339..4aaf5261 100644 --- a/TTS/tts/models/base_tacotron.py +++ b/TTS/tts/models/base_tacotron.py @@ -92,16 +92,17 @@ class BaseTacotron(BaseTTS): pass def load_checkpoint( - self, config, checkpoint_path, eval=False + self, config, checkpoint_path, eval=False, cache=False ): # pylint: disable=unused-argument, redefined-builtin """Load model checkpoint and set up internals. Args: config (Coqpi): model configuration. checkpoint_path (str): path to checkpoint file. - eval (bool): whether to load model for evaluation. + eval (bool, optional): whether to load model for evaluation. + cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False. 
""" - state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) self.load_state_dict(state["model"]) # TODO: set r in run-time by taking it from the new config if "r" in state: diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index a1273f7f..c1132df2 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -16,6 +16,7 @@ from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram +from TTS.utils.io import load_fsspec @dataclass @@ -707,9 +708,9 @@ class ForwardTTS(BaseTTS): logger.eval_audios(steps, audios, self.ap.sample_rate) def load_checkpoint( - self, config, checkpoint_path, eval=False + self, config, checkpoint_path, eval=False, cache=False ): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) self.load_state_dict(state["model"]) if eval: self.eval() diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index c81d9203..4959d7ba 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -1686,14 +1686,10 @@ class Vits(BaseTTS): return [VitsDiscriminatorLoss(self.config), VitsGeneratorLoss(self.config)] def load_checkpoint( - self, - config, - checkpoint_path, - eval=False, - strict=True, + self, config, checkpoint_path, eval=False, strict=True, cache=False ): # pylint: disable=unused-argument, redefined-builtin """Load the model checkpoint and setup for training or inference""" - state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) # compat band-aid for the pre-trained models to not use the encoder baked into the model # TODO: consider baking the speaker encoder into the model and call it from there. # as it is probably easier for model distribution. diff --git a/TTS/utils/io.py b/TTS/utils/io.py index 0b32f77a..e9bdf3e6 100644 --- a/TTS/utils/io.py +++ b/TTS/utils/io.py @@ -9,6 +9,8 @@ import fsspec import torch from coqpit import Coqpit +from TTS.utils.generic_utils import get_user_data_dir + class RenamingUnpickler(pickle_tts.Unpickler): """Overload default pickler to solve module renaming problem""" @@ -57,6 +59,7 @@ def copy_model_files(config: Coqpit, out_path, new_fields=None): def load_fsspec( path: str, map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None, + cache: bool = True, **kwargs, ) -> Any: """Like torch.load but can load from other locations (e.g. s3:// , gs://). @@ -64,21 +67,33 @@ def load_fsspec( Args: path: Any path or url supported by fsspec. map_location: torch.device or str. + cache: If True, cache a remote file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to True. **kwargs: Keyword arguments forwarded to torch.load. Returns: Object stored in path. 
""" - with fsspec.open(path, "rb") as f: - return torch.load(f, map_location=map_location, **kwargs) + is_local = os.path.isdir(path) or os.path.isfile(path) + if cache and not is_local: + with fsspec.open( + f"filecache::{path}", + filecache={"cache_storage": str(get_user_data_dir("tts_cache"))}, + mode="rb", + ) as f: + return torch.load(f, map_location=map_location, **kwargs) + else: + with fsspec.open(path, "rb") as f: + return torch.load(f, map_location=map_location, **kwargs) -def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin +def load_checkpoint( + model, checkpoint_path, use_cuda=False, eval=False, cache=False +): # pylint: disable=redefined-builtin try: - state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) except ModuleNotFoundError: pickle_tts.Unpickler = RenamingUnpickler - state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache) model.load_state_dict(state["model"]) if use_cuda: model.cuda() diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 5eed6683..645099e0 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -32,11 +32,14 @@ class ModelManager(object): home path. Args: - models_file (str): path to .model.json + models_file (str): path to .model.json file. Defaults to None. + output_prefix (str): prefix to `tts` to download models. Defaults to None + progress_bar (bool): print a progress bar when donwloading a file. Defaults to False. """ - def __init__(self, models_file=None, output_prefix=None): + def __init__(self, models_file=None, output_prefix=None, progress_bar=False): super().__init__() + self.progress_bar = progress_bar if output_prefix is None: self.output_prefix = get_user_data_dir("tts") else: @@ -236,7 +239,7 @@ class ModelManager(object): os.makedirs(output_path, exist_ok=True) print(f" > Downloading model to {output_path}") # download from github release - self._download_zip_file(model_item["github_rls_url"], output_path) + self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar) self.print_model_license(model_item=model_item) # find downloaded files output_model_path, output_config_path = self._find_files(output_path) @@ -334,7 +337,7 @@ class ModelManager(object): config.save_json(config_path) @staticmethod - def _download_zip_file(file_url, output_folder): + def _download_zip_file(file_url, output_folder, progress_bar): """Download the github releases""" # download the file r = requests.get(file_url, stream=True) @@ -342,11 +345,13 @@ class ModelManager(object): try: total_size_in_bytes = int(r.headers.get("content-length", 0)) block_size = 1024 # 1 Kibibyte - progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) + if progress_bar: + progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1]) with open(temp_zip_name, "wb") as file: for data in r.iter_content(block_size): - progress_bar.update(len(data)) + if progress_bar: + progress_bar.update(len(data)) file.write(data) with zipfile.ZipFile(temp_zip_name) as z: z.extractall(output_folder) diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index a3803f77..19c30e98 100644 --- a/TTS/vocoder/models/gan.py +++ 
@@ -231,6 +231,7 @@ class GAN(BaseVocoder):
         config: Coqpit,
         checkpoint_path: str,
         eval: bool = False,  # pylint: disable=unused-argument, redefined-builtin
+        cache: bool = False,
     ) -> None:
         """Load a GAN checkpoint and initialize model parameters.
 
@@ -239,7 +240,7 @@ class GAN(BaseVocoder):
             checkpoint_path (str): Checkpoint file path.
             eval (bool, optional): If true, load the model for inference. Defaults to False.
         """
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         # band-aid for older than v0.0.15 GAN models
         if "model_disc" in state:
             self.model_g.load_checkpoint(config, checkpoint_path, eval)
diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py
index fc15f3af..7c6ad9b6 100644
--- a/TTS/vocoder/models/hifigan_generator.py
+++ b/TTS/vocoder/models/hifigan_generator.py
@@ -290,9 +290,9 @@ class HifiganGenerator(torch.nn.Module):
         remove_weight_norm(self.conv_post)
 
     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
diff --git a/TTS/vocoder/models/melgan_generator.py b/TTS/vocoder/models/melgan_generator.py
index 80b47870..989797f0 100644
--- a/TTS/vocoder/models/melgan_generator.py
+++ b/TTS/vocoder/models/melgan_generator.py
@@ -85,9 +85,9 @@ class MelganGenerator(nn.Module):
                     layer.remove_weight_norm()
 
     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
diff --git a/TTS/vocoder/models/parallel_wavegan_generator.py b/TTS/vocoder/models/parallel_wavegan_generator.py
index ee9d8ad5..c741774a 100644
--- a/TTS/vocoder/models/parallel_wavegan_generator.py
+++ b/TTS/vocoder/models/parallel_wavegan_generator.py
@@ -153,9 +153,9 @@ class ParallelWaveganGenerator(torch.nn.Module):
         return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
 
     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py
index c4968f1f..a0f9221a 100644
--- a/TTS/vocoder/models/wavegrad.py
+++ b/TTS/vocoder/models/wavegrad.py
@@ -218,9 +218,9 @@ class Wavegrad(BaseVocoder):
         self.y_conv = weight_norm(self.y_conv)
 
     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"]) if eval: self.eval() diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index e0a25e32..0ea6b6e0 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -542,9 +542,9 @@ class Wavernn(BaseVocoder): return unfolded def load_checkpoint( - self, config, checkpoint_path, eval=False + self, config, checkpoint_path, eval=False, cache=False ): # pylint: disable=unused-argument, redefined-builtin - state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) self.load_state_dict(state["model"]) if eval: self.eval() diff --git a/tests/aux_tests/test_extract_tts_spectrograms.py b/tests/aux_tests/test_extract_tts_spectrograms.py index ef751846..f9392706 100644 --- a/tests/aux_tests/test_extract_tts_spectrograms.py +++ b/tests/aux_tests/test_extract_tts_spectrograms.py @@ -15,7 +15,7 @@ class TestExtractTTSSpectrograms(unittest.TestCase): def test_GlowTTS(): # set paths config_path = os.path.join(get_tests_input_path(), "test_glow_tts.json") - checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth") + checkpoint_path = os.path.join(get_tests_output_path(), "glowtts.pth") output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/") # load config c = load_config(config_path) @@ -33,7 +33,7 @@ class TestExtractTTSSpectrograms(unittest.TestCase): def test_Tacotron2(): # set paths config_path = os.path.join(get_tests_input_path(), "test_tacotron2_config.json") - checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth") + checkpoint_path = os.path.join(get_tests_output_path(), "tacotron2.pth") output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/") # load config c = load_config(config_path) @@ -51,7 +51,7 @@ class TestExtractTTSSpectrograms(unittest.TestCase): def test_Tacotron(): # set paths config_path = os.path.join(get_tests_input_path(), "test_tacotron_config.json") - checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth") + checkpoint_path = os.path.join(get_tests_output_path(), "tacotron.pth") output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/") # load config c = load_config(config_path) diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py index d3c7c54b..7105edf4 100644 --- a/tests/zoo_tests/test_models.py +++ b/tests/zoo_tests/test_models.py @@ -15,7 +15,7 @@ def test_run_all_models(): print(" > Run synthesizer with all the models.") download_dir = get_user_data_dir("tts") output_path = os.path.join(get_tests_output_path(), "output.wav") - manager = ModelManager(output_prefix=get_tests_output_path()) + manager = ModelManager(output_prefix=get_tests_output_path(), progress_bar=False) model_names = manager.list_models() for model_name in model_names: print(f"\n > Run - {model_name}") @@ -41,11 +41,14 @@ def test_run_all_models(): speaker_id = list(speaker_manager.name_to_id.keys())[0] run_cli( f"tts --model_name {model_name} " - f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" ' + f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" --progress_bar False' ) else: # single-speaker model - run_cli(f"tts --model_name {model_name} " f'--text "This is an example." 
--out_path "{output_path}"') + run_cli( + f"tts --model_name {model_name} " + f'--text "This is an example." --out_path "{output_path}" --progress_bar False' + ) # remove downloaded models shutil.rmtree(download_dir) else: @@ -67,5 +70,5 @@ def test_voice_conversion(): output_path = os.path.join(get_tests_output_path(), "output.wav") run_cli( f"tts --model_name {model_name}" - f" --out_path {output_path} --speaker_wav {speaker_wav} --reference_wav {reference_wav} --language_idx {language_id} " + f" --out_path {output_path} --speaker_wav {speaker_wav} --reference_wav {reference_wav} --language_idx {language_id} --progress_bar False" )