Cache fsspec downloads (#2132)

* Cache fsspec downloaded files

* Use diff paths for test

* Make fsspec caching optional

* Decom GPU docker tests

* Make progress bar optional for better CI log

* Check path local
This commit is contained in:
Eren Gölge 2022-11-09 22:12:48 +01:00 committed by GitHub
parent c5412532ac
commit 8cb1433e6e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 86 additions and 48 deletions

View File

@ -15,7 +15,7 @@ jobs:
matrix: matrix:
arch: ["amd64"] arch: ["amd64"]
base: base:
- "nvcr.io/nvidia/pytorch:22.03-py3" # GPU enabled # - "nvcr.io/nvidia/pytorch:22.03-py3" # GPU enabled
- "ubuntu:20.04" # CPU only - "ubuntu:20.04" # CPU only
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2

View File

@ -238,6 +238,13 @@ If you don't specify any models, then it uses LJSpeech based English model.
help="speaker ID of the reference_wav speaker (If not provided the embedding will be computed using the Speaker Encoder).", help="speaker ID of the reference_wav speaker (If not provided the embedding will be computed using the Speaker Encoder).",
default=None, default=None,
) )
parser.add_argument(
"--progress_bar",
type=str2bool,
help="If true shows a progress bar for the model download. Defaults to True",
default=True,
)
args = parser.parse_args() args = parser.parse_args()
# print the description if either text or list_models is not set # print the description if either text or list_models is not set
@ -255,7 +262,7 @@ If you don't specify any models, then it uses LJSpeech based English model.
# load model manager # load model manager
path = Path(__file__).parent / "../.models.json" path = Path(__file__).parent / "../.models.json"
manager = ModelManager(path) manager = ModelManager(path, progress_bar=args.progress_bar)
model_path = None model_path = None
config_path = None config_path = None

View File

@ -107,9 +107,15 @@ class BaseEncoder(nn.Module):
return criterion return criterion
def load_checkpoint( def load_checkpoint(
self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None self,
config: Coqpit,
checkpoint_path: str,
eval: bool = False,
use_cuda: bool = False,
criterion=None,
cache=False,
): ):
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
try: try:
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
print(" > Model fully restored. ") print(" > Model fully restored. ")

View File

@ -44,13 +44,16 @@ class BaseTrainerModel(TrainerModel):
return outputs_dict return outputs_dict
@abstractmethod @abstractmethod
def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None: def load_checkpoint(
self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False
) -> None:
"""Load a model checkpoint gile and get ready for training or inference. """Load a model checkpoint gile and get ready for training or inference.
Args: Args:
config (Coqpit): Model configuration. config (Coqpit): Model configuration.
checkpoint_path (str): Path to the model checkpoint file. checkpoint_path (str): Path to the model checkpoint file.
eval (bool, optional): If true, init model for inference else for training. Defaults to False. eval (bool, optional): If true, init model for inference else for training. Defaults to False.
strcit (bool, optional): Match all checkpoint keys to model's keys. Defaults to True. strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
""" """
... ...

View File

@ -398,9 +398,9 @@ class AlignTTS(BaseTTS):
logger.eval_audios(steps, audios, self.ap.sample_rate) logger.eval_audios(steps, audios, self.ap.sample_rate)
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -92,16 +92,17 @@ class BaseTacotron(BaseTTS):
pass pass
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
"""Load model checkpoint and set up internals. """Load model checkpoint and set up internals.
Args: Args:
config (Coqpi): model configuration. config (Coqpi): model configuration.
checkpoint_path (str): path to checkpoint file. checkpoint_path (str): path to checkpoint file.
eval (bool): whether to load model for evaluation. eval (bool, optional): whether to load model for evaluation.
cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
""" """
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
# TODO: set r in run-time by taking it from the new config # TODO: set r in run-time by taking it from the new config
if "r" in state: if "r" in state:

View File

@ -16,6 +16,7 @@ from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum
from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
from TTS.utils.io import load_fsspec
@dataclass @dataclass
@ -707,9 +708,9 @@ class ForwardTTS(BaseTTS):
logger.eval_audios(steps, audios, self.ap.sample_rate) logger.eval_audios(steps, audios, self.ap.sample_rate)
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -1686,14 +1686,10 @@ class Vits(BaseTTS):
return [VitsDiscriminatorLoss(self.config), VitsGeneratorLoss(self.config)] return [VitsDiscriminatorLoss(self.config), VitsGeneratorLoss(self.config)]
def load_checkpoint( def load_checkpoint(
self, self, config, checkpoint_path, eval=False, strict=True, cache=False
config,
checkpoint_path,
eval=False,
strict=True,
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
"""Load the model checkpoint and setup for training or inference""" """Load the model checkpoint and setup for training or inference"""
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
# compat band-aid for the pre-trained models to not use the encoder baked into the model # compat band-aid for the pre-trained models to not use the encoder baked into the model
# TODO: consider baking the speaker encoder into the model and call it from there. # TODO: consider baking the speaker encoder into the model and call it from there.
# as it is probably easier for model distribution. # as it is probably easier for model distribution.

View File

@ -9,6 +9,8 @@ import fsspec
import torch import torch
from coqpit import Coqpit from coqpit import Coqpit
from TTS.utils.generic_utils import get_user_data_dir
class RenamingUnpickler(pickle_tts.Unpickler): class RenamingUnpickler(pickle_tts.Unpickler):
"""Overload default pickler to solve module renaming problem""" """Overload default pickler to solve module renaming problem"""
@ -57,6 +59,7 @@ def copy_model_files(config: Coqpit, out_path, new_fields=None):
def load_fsspec( def load_fsspec(
path: str, path: str,
map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None, map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
cache: bool = True,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Like torch.load but can load from other locations (e.g. s3:// , gs://). """Like torch.load but can load from other locations (e.g. s3:// , gs://).
@ -64,21 +67,33 @@ def load_fsspec(
Args: Args:
path: Any path or url supported by fsspec. path: Any path or url supported by fsspec.
map_location: torch.device or str. map_location: torch.device or str.
cache: If True, cache a remote file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to True.
**kwargs: Keyword arguments forwarded to torch.load. **kwargs: Keyword arguments forwarded to torch.load.
Returns: Returns:
Object stored in path. Object stored in path.
""" """
with fsspec.open(path, "rb") as f: is_local = os.path.isdir(path) or os.path.isfile(path)
return torch.load(f, map_location=map_location, **kwargs) if cache and not is_local:
with fsspec.open(
f"filecache::{path}",
filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
mode="rb",
) as f:
return torch.load(f, map_location=map_location, **kwargs)
else:
with fsspec.open(path, "rb") as f:
return torch.load(f, map_location=map_location, **kwargs)
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin def load_checkpoint(
model, checkpoint_path, use_cuda=False, eval=False, cache=False
): # pylint: disable=redefined-builtin
try: try:
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
except ModuleNotFoundError: except ModuleNotFoundError:
pickle_tts.Unpickler = RenamingUnpickler pickle_tts.Unpickler = RenamingUnpickler
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache)
model.load_state_dict(state["model"]) model.load_state_dict(state["model"])
if use_cuda: if use_cuda:
model.cuda() model.cuda()

View File

@ -32,11 +32,14 @@ class ModelManager(object):
home path. home path.
Args: Args:
models_file (str): path to .model.json models_file (str): path to .model.json file. Defaults to None.
output_prefix (str): prefix to `tts` to download models. Defaults to None
progress_bar (bool): print a progress bar when donwloading a file. Defaults to False.
""" """
def __init__(self, models_file=None, output_prefix=None): def __init__(self, models_file=None, output_prefix=None, progress_bar=False):
super().__init__() super().__init__()
self.progress_bar = progress_bar
if output_prefix is None: if output_prefix is None:
self.output_prefix = get_user_data_dir("tts") self.output_prefix = get_user_data_dir("tts")
else: else:
@ -236,7 +239,7 @@ class ModelManager(object):
os.makedirs(output_path, exist_ok=True) os.makedirs(output_path, exist_ok=True)
print(f" > Downloading model to {output_path}") print(f" > Downloading model to {output_path}")
# download from github release # download from github release
self._download_zip_file(model_item["github_rls_url"], output_path) self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
self.print_model_license(model_item=model_item) self.print_model_license(model_item=model_item)
# find downloaded files # find downloaded files
output_model_path, output_config_path = self._find_files(output_path) output_model_path, output_config_path = self._find_files(output_path)
@ -334,7 +337,7 @@ class ModelManager(object):
config.save_json(config_path) config.save_json(config_path)
@staticmethod @staticmethod
def _download_zip_file(file_url, output_folder): def _download_zip_file(file_url, output_folder, progress_bar):
"""Download the github releases""" """Download the github releases"""
# download the file # download the file
r = requests.get(file_url, stream=True) r = requests.get(file_url, stream=True)
@ -342,11 +345,13 @@ class ModelManager(object):
try: try:
total_size_in_bytes = int(r.headers.get("content-length", 0)) total_size_in_bytes = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) if progress_bar:
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1]) temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
with open(temp_zip_name, "wb") as file: with open(temp_zip_name, "wb") as file:
for data in r.iter_content(block_size): for data in r.iter_content(block_size):
progress_bar.update(len(data)) if progress_bar:
progress_bar.update(len(data))
file.write(data) file.write(data)
with zipfile.ZipFile(temp_zip_name) as z: with zipfile.ZipFile(temp_zip_name) as z:
z.extractall(output_folder) z.extractall(output_folder)

View File

@ -231,6 +231,7 @@ class GAN(BaseVocoder):
config: Coqpit, config: Coqpit,
checkpoint_path: str, checkpoint_path: str,
eval: bool = False, # pylint: disable=unused-argument, redefined-builtin eval: bool = False, # pylint: disable=unused-argument, redefined-builtin
cache: bool = False,
) -> None: ) -> None:
"""Load a GAN checkpoint and initialize model parameters. """Load a GAN checkpoint and initialize model parameters.
@ -239,7 +240,7 @@ class GAN(BaseVocoder):
checkpoint_path (str): Checkpoint file path. checkpoint_path (str): Checkpoint file path.
eval (bool, optional): If true, load the model for inference. If falseDefaults to False. eval (bool, optional): If true, load the model for inference. If falseDefaults to False.
""" """
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
# band-aid for older than v0.0.15 GAN models # band-aid for older than v0.0.15 GAN models
if "model_disc" in state: if "model_disc" in state:
self.model_g.load_checkpoint(config, checkpoint_path, eval) self.model_g.load_checkpoint(config, checkpoint_path, eval)

View File

@ -290,9 +290,9 @@ class HifiganGenerator(torch.nn.Module):
remove_weight_norm(self.conv_post) remove_weight_norm(self.conv_post)
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -85,9 +85,9 @@ class MelganGenerator(nn.Module):
layer.remove_weight_norm() layer.remove_weight_norm()
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -153,9 +153,9 @@ class ParallelWaveganGenerator(torch.nn.Module):
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size) return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -218,9 +218,9 @@ class Wavegrad(BaseVocoder):
self.y_conv = weight_norm(self.y_conv) self.y_conv = weight_norm(self.y_conv)
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -542,9 +542,9 @@ class Wavernn(BaseVocoder):
return unfolded return unfolded
def load_checkpoint( def load_checkpoint(
self, config, checkpoint_path, eval=False self, config, checkpoint_path, eval=False, cache=False
): # pylint: disable=unused-argument, redefined-builtin ): # pylint: disable=unused-argument, redefined-builtin
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
self.load_state_dict(state["model"]) self.load_state_dict(state["model"])
if eval: if eval:
self.eval() self.eval()

View File

@ -15,7 +15,7 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
def test_GlowTTS(): def test_GlowTTS():
# set paths # set paths
config_path = os.path.join(get_tests_input_path(), "test_glow_tts.json") config_path = os.path.join(get_tests_input_path(), "test_glow_tts.json")
checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth") checkpoint_path = os.path.join(get_tests_output_path(), "glowtts.pth")
output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/") output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
# load config # load config
c = load_config(config_path) c = load_config(config_path)
@ -33,7 +33,7 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
def test_Tacotron2(): def test_Tacotron2():
# set paths # set paths
config_path = os.path.join(get_tests_input_path(), "test_tacotron2_config.json") config_path = os.path.join(get_tests_input_path(), "test_tacotron2_config.json")
checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth") checkpoint_path = os.path.join(get_tests_output_path(), "tacotron2.pth")
output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/") output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
# load config # load config
c = load_config(config_path) c = load_config(config_path)
@ -51,7 +51,7 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
def test_Tacotron(): def test_Tacotron():
# set paths # set paths
config_path = os.path.join(get_tests_input_path(), "test_tacotron_config.json") config_path = os.path.join(get_tests_input_path(), "test_tacotron_config.json")
checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth") checkpoint_path = os.path.join(get_tests_output_path(), "tacotron.pth")
output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/") output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
# load config # load config
c = load_config(config_path) c = load_config(config_path)

View File

@ -15,7 +15,7 @@ def test_run_all_models():
print(" > Run synthesizer with all the models.") print(" > Run synthesizer with all the models.")
download_dir = get_user_data_dir("tts") download_dir = get_user_data_dir("tts")
output_path = os.path.join(get_tests_output_path(), "output.wav") output_path = os.path.join(get_tests_output_path(), "output.wav")
manager = ModelManager(output_prefix=get_tests_output_path()) manager = ModelManager(output_prefix=get_tests_output_path(), progress_bar=False)
model_names = manager.list_models() model_names = manager.list_models()
for model_name in model_names: for model_name in model_names:
print(f"\n > Run - {model_name}") print(f"\n > Run - {model_name}")
@ -41,11 +41,14 @@ def test_run_all_models():
speaker_id = list(speaker_manager.name_to_id.keys())[0] speaker_id = list(speaker_manager.name_to_id.keys())[0]
run_cli( run_cli(
f"tts --model_name {model_name} " f"tts --model_name {model_name} "
f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" ' f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" --progress_bar False'
) )
else: else:
# single-speaker model # single-speaker model
run_cli(f"tts --model_name {model_name} " f'--text "This is an example." --out_path "{output_path}"') run_cli(
f"tts --model_name {model_name} "
f'--text "This is an example." --out_path "{output_path}" --progress_bar False'
)
# remove downloaded models # remove downloaded models
shutil.rmtree(download_dir) shutil.rmtree(download_dir)
else: else:
@ -67,5 +70,5 @@ def test_voice_conversion():
output_path = os.path.join(get_tests_output_path(), "output.wav") output_path = os.path.join(get_tests_output_path(), "output.wav")
run_cli( run_cli(
f"tts --model_name {model_name}" f"tts --model_name {model_name}"
f" --out_path {output_path} --speaker_wav {speaker_wav} --reference_wav {reference_wav} --language_idx {language_id} " f" --out_path {output_path} --speaker_wav {speaker_wav} --reference_wav {reference_wav} --language_idx {language_id} --progress_bar False"
) )