mirror of https://github.com/coqui-ai/TTS.git
Cache fsspec downloads (#2132)
* Cache fsspec downloaded files
* Use diff paths for test
* Make fsspec caching optional
* Decom GPU docker tests
* Make progress bar optional for better CI log
* Check path local
This commit is contained in:
parent
c5412532ac
commit
8cb1433e6e
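
For context: the caching added in this commit is built on fsspec's `filecache` protocol chaining, which downloads a remote file on first read and serves later reads from the local copy. A minimal sketch of the mechanism, assuming a hypothetical URL and cache directory (neither appears in this commit):

import fsspec

# First open downloads the remote file into cache_storage;
# subsequent opens of the same URL read the cached local copy.
with fsspec.open(
    "filecache::https://example.com/model.pth",  # hypothetical URL
    filecache={"cache_storage": "/tmp/tts_cache"},  # hypothetical cache dir
    mode="rb",
) as f:
    payload = f.read()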
@@ -15,7 +15,7 @@ jobs:
       matrix:
         arch: ["amd64"]
         base:
-          - "nvcr.io/nvidia/pytorch:22.03-py3" # GPU enabled
+          # - "nvcr.io/nvidia/pytorch:22.03-py3" # GPU enabled
           - "ubuntu:20.04" # CPU only
     steps:
       - uses: actions/checkout@v2
@@ -238,6 +238,13 @@ If you don't specify any models, then it uses LJSpeech based English model.
         help="speaker ID of the reference_wav speaker (If not provided the embedding will be computed using the Speaker Encoder).",
         default=None,
     )
+    parser.add_argument(
+        "--progress_bar",
+        type=str2bool,
+        help="If true shows a progress bar for the model download. Defaults to True",
+        default=True,
+    )
+
     args = parser.parse_args()

     # print the description if either text or list_models is not set
@@ -255,7 +262,7 @@ If you don't specify any models, then it uses LJSpeech based English model.

     # load model manager
     path = Path(__file__).parent / "../.models.json"
-    manager = ModelManager(path)
+    manager = ModelManager(path, progress_bar=args.progress_bar)

     model_path = None
     config_path = None
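
Note on the new flag: `--progress_bar` is parsed with `type=str2bool` rather than `type=bool`, since `bool("False")` is truthy in Python and argparse would otherwise treat any non-empty string as true. A typical converter looks like the sketch below; the repository's actual helper may differ:

import argparse

def str2bool(v):
    # Map common command-line spellings onto real booleans.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")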
@@ -107,9 +107,15 @@ class BaseEncoder(nn.Module):
         return criterion

     def load_checkpoint(
-        self, config: Coqpit, checkpoint_path: str, eval: bool = False, use_cuda: bool = False, criterion=None
+        self,
+        config: Coqpit,
+        checkpoint_path: str,
+        eval: bool = False,
+        use_cuda: bool = False,
+        criterion=None,
+        cache=False,
     ):
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         try:
             self.load_state_dict(state["model"])
             print(" > Model fully restored. ")
@@ -44,13 +44,16 @@ class BaseTrainerModel(TrainerModel):
         return outputs_dict

     @abstractmethod
-    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True) -> None:
+    def load_checkpoint(
+        self, config: Coqpit, checkpoint_path: str, eval: bool = False, strict: bool = True, cache=False
+    ) -> None:
         """Load a model checkpoint file and get ready for training or inference.

         Args:
             config (Coqpit): Model configuration.
             checkpoint_path (str): Path to the model checkpoint file.
             eval (bool, optional): If true, init model for inference else for training. Defaults to False.
-            strcit (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
+            strict (bool, optional): Match all checkpoint keys to model's keys. Defaults to True.
+            cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
         """
         ...
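
To make the extended contract concrete, a subclass implements the abstract method along these lines (a hedged sketch; `MyTTSModel` is hypothetical and not part of the commit):

import torch

from TTS.model import BaseTrainerModel
from TTS.utils.io import load_fsspec

class MyTTSModel(BaseTrainerModel):  # hypothetical subclass for illustration
    def load_checkpoint(self, config, checkpoint_path, eval=False, strict=True, cache=False):
        # Forward the cache flag so remote checkpoints are cached locally on first load.
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
        self.load_state_dict(state["model"], strict=strict)
        if eval:
            self.eval()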
@@ -398,9 +398,9 @@ class AlignTTS(BaseTTS):
             logger.eval_audios(steps, audios, self.ap.sample_rate)

     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -92,16 +92,17 @@ class BaseTacotron(BaseTTS):
         pass

     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
         """Load model checkpoint and set up internals.

         Args:
             config (Coqpit): model configuration.
             checkpoint_path (str): path to checkpoint file.
-            eval (bool): whether to load model for evaluation.
+            eval (bool, optional): whether to load model for evaluation.
+            cache (bool, optional): If True, cache the file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to False.
         """
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         # TODO: set r in run-time by taking it from the new config
         if "r" in state:
@@ -16,6 +16,7 @@ from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_spectrogram
+from TTS.utils.io import load_fsspec


 @dataclass
@@ -707,9 +708,9 @@ class ForwardTTS(BaseTTS):
             logger.eval_audios(steps, audios, self.ap.sample_rate)

     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -1686,14 +1686,10 @@ class Vits(BaseTTS):
         return [VitsDiscriminatorLoss(self.config), VitsGeneratorLoss(self.config)]

     def load_checkpoint(
-        self,
-        config,
-        checkpoint_path,
-        eval=False,
-        strict=True,
+        self, config, checkpoint_path, eval=False, strict=True, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
         """Load the model checkpoint and setup for training or inference"""
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         # compat band-aid for the pre-trained models to not use the encoder baked into the model
         # TODO: consider baking the speaker encoder into the model and call it from there.
         # as it is probably easier for model distribution.
@@ -9,6 +9,8 @@ import fsspec
 import torch
 from coqpit import Coqpit

+from TTS.utils.generic_utils import get_user_data_dir
+

 class RenamingUnpickler(pickle_tts.Unpickler):
     """Overload default pickler to solve module renaming problem"""
@@ -57,6 +59,7 @@ def copy_model_files(config: Coqpit, out_path, new_fields=None):
 def load_fsspec(
     path: str,
     map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
+    cache: bool = True,
     **kwargs,
 ) -> Any:
     """Like torch.load but can load from other locations (e.g. s3:// , gs://).
@@ -64,21 +67,33 @@ def load_fsspec(
     Args:
         path: Any path or url supported by fsspec.
        map_location: torch.device or str.
+        cache: If True, cache a remote file locally for subsequent calls. It is cached under `get_user_data_dir()/tts_cache`. Defaults to True.
         **kwargs: Keyword arguments forwarded to torch.load.

     Returns:
         Object stored in path.
     """
-    with fsspec.open(path, "rb") as f:
-        return torch.load(f, map_location=map_location, **kwargs)
+    is_local = os.path.isdir(path) or os.path.isfile(path)
+    if cache and not is_local:
+        with fsspec.open(
+            f"filecache::{path}",
+            filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
+            mode="rb",
+        ) as f:
+            return torch.load(f, map_location=map_location, **kwargs)
+    else:
+        with fsspec.open(path, "rb") as f:
+            return torch.load(f, map_location=map_location, **kwargs)


-def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
+def load_checkpoint(
+    model, checkpoint_path, use_cuda=False, eval=False, cache=False
+):  # pylint: disable=redefined-builtin
     try:
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
     except ModuleNotFoundError:
         pickle_tts.Unpickler = RenamingUnpickler
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache)
     model.load_state_dict(state["model"])
     if use_cuda:
         model.cuda()
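
Two details worth noting in the new `load_fsspec` body: local paths bypass the cache entirely (the `is_local` check), and `cache` defaults to True here while every model-level `load_checkpoint` defaults to False, so callers opt in per checkpoint. A hedged usage sketch with a hypothetical remote URL:

from TTS.utils.io import load_fsspec

# Remote checkpoint: fetched once, then served from get_user_data_dir()/tts_cache.
state = load_fsspec("https://example.com/checkpoint.pth", map_location="cpu")  # hypothetical URL

# Local checkpoint: the is_local check skips the cache and reads the file directly.
state = load_fsspec("/path/to/checkpoint.pth", map_location="cpu")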
@@ -32,11 +32,14 @@ class ModelManager(object):
     home path.

     Args:
-        models_file (str): path to .model.json
+        models_file (str): path to .model.json file. Defaults to None.
+        output_prefix (str): prefix to `tts` to download models. Defaults to None
+        progress_bar (bool): print a progress bar when downloading a file. Defaults to False.
     """

-    def __init__(self, models_file=None, output_prefix=None):
+    def __init__(self, models_file=None, output_prefix=None, progress_bar=False):
         super().__init__()
+        self.progress_bar = progress_bar
         if output_prefix is None:
             self.output_prefix = get_user_data_dir("tts")
         else:
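
Downstream, silencing downloads is now a constructor choice. A short usage sketch based on the signature above:

from TTS.utils.manage import ModelManager

# Quiet manager for CI logs: no tqdm output while models download.
manager = ModelManager(progress_bar=False)
print(manager.list_models()[:5])  # peek at a few available model names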
@@ -236,7 +239,7 @@ class ModelManager(object):
         os.makedirs(output_path, exist_ok=True)
         print(f" > Downloading model to {output_path}")
         # download from github release
-        self._download_zip_file(model_item["github_rls_url"], output_path)
+        self._download_zip_file(model_item["github_rls_url"], output_path, self.progress_bar)
         self.print_model_license(model_item=model_item)
         # find downloaded files
         output_model_path, output_config_path = self._find_files(output_path)
@@ -334,7 +337,7 @@ class ModelManager(object):
         config.save_json(config_path)

     @staticmethod
-    def _download_zip_file(file_url, output_folder):
+    def _download_zip_file(file_url, output_folder, progress_bar):
         """Download the github releases"""
         # download the file
         r = requests.get(file_url, stream=True)
@@ -342,11 +345,13 @@ class ModelManager(object):
         try:
             total_size_in_bytes = int(r.headers.get("content-length", 0))
             block_size = 1024  # 1 Kibibyte
-            progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
+            if progress_bar:
+                progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
             temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
             with open(temp_zip_name, "wb") as file:
                 for data in r.iter_content(block_size):
-                    progress_bar.update(len(data))
+                    if progress_bar:
+                        progress_bar.update(len(data))
                     file.write(data)
             with zipfile.ZipFile(temp_zip_name) as z:
                 z.extractall(output_folder)
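
One wrinkle in the hunk above: when the flag is set, the boolean `progress_bar` parameter is rebound to a tqdm instance, so the same name holds two types and the `if progress_bar:` guards serve both cases. An equivalent, arguably clearer structure with a separate name, shown as a self-contained sketch (the `download` helper is hypothetical, not the commit's code):

import requests
from tqdm import tqdm

def download(url, dest, progress_bar=True):  # hypothetical helper
    # Stream the file to disk, showing a tqdm bar only when requested.
    r = requests.get(url, stream=True)
    total = int(r.headers.get("content-length", 0))
    pbar = tqdm(total=total, unit="iB", unit_scale=True) if progress_bar else None
    with open(dest, "wb") as f:
        for chunk in r.iter_content(1024):
            if pbar is not None:
                pbar.update(len(chunk))
            f.write(chunk)
    if pbar is not None:
        pbar.close()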
@@ -231,6 +231,7 @@ class GAN(BaseVocoder):
         config: Coqpit,
         checkpoint_path: str,
         eval: bool = False,  # pylint: disable=unused-argument, redefined-builtin
+        cache: bool = False,
     ) -> None:
         """Load a GAN checkpoint and initialize model parameters.

@@ -239,7 +240,7 @@ class GAN(BaseVocoder):
             checkpoint_path (str): Checkpoint file path.
             eval (bool, optional): If true, load the model for inference. Defaults to False.
         """
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         # band-aid for older than v0.0.15 GAN models
         if "model_disc" in state:
             self.model_g.load_checkpoint(config, checkpoint_path, eval)
@@ -290,9 +290,9 @@ class HifiganGenerator(torch.nn.Module):
         remove_weight_norm(self.conv_post)

     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -85,9 +85,9 @@ class MelganGenerator(nn.Module):
         layer.remove_weight_norm()

     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -153,9 +153,9 @@ class ParallelWaveganGenerator(torch.nn.Module):
         return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)

     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -218,9 +218,9 @@ class Wavegrad(BaseVocoder):
         self.y_conv = weight_norm(self.y_conv)

     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -542,9 +542,9 @@ class Wavernn(BaseVocoder):
         return unfolded

     def load_checkpoint(
-        self, config, checkpoint_path, eval=False
+        self, config, checkpoint_path, eval=False, cache=False
     ):  # pylint: disable=unused-argument, redefined-builtin
-        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
@@ -15,7 +15,7 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
     def test_GlowTTS():
         # set paths
        config_path = os.path.join(get_tests_input_path(), "test_glow_tts.json")
-        checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth")
+        checkpoint_path = os.path.join(get_tests_output_path(), "glowtts.pth")
         output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
         # load config
         c = load_config(config_path)
@@ -33,7 +33,7 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
     def test_Tacotron2():
         # set paths
         config_path = os.path.join(get_tests_input_path(), "test_tacotron2_config.json")
-        checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth")
+        checkpoint_path = os.path.join(get_tests_output_path(), "tacotron2.pth")
         output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
         # load config
         c = load_config(config_path)
@@ -51,7 +51,7 @@ class TestExtractTTSSpectrograms(unittest.TestCase):
     def test_Tacotron():
         # set paths
         config_path = os.path.join(get_tests_input_path(), "test_tacotron_config.json")
-        checkpoint_path = os.path.join(get_tests_output_path(), "checkpoint_test.pth")
+        checkpoint_path = os.path.join(get_tests_output_path(), "tacotron.pth")
         output_path = os.path.join(get_tests_output_path(), "output_extract_tts_spectrograms/")
         # load config
         c = load_config(config_path)
@@ -15,7 +15,7 @@ def test_run_all_models():
     print(" > Run synthesizer with all the models.")
     download_dir = get_user_data_dir("tts")
     output_path = os.path.join(get_tests_output_path(), "output.wav")
-    manager = ModelManager(output_prefix=get_tests_output_path())
+    manager = ModelManager(output_prefix=get_tests_output_path(), progress_bar=False)
     model_names = manager.list_models()
     for model_name in model_names:
         print(f"\n > Run - {model_name}")
@@ -41,11 +41,14 @@ def test_run_all_models():
             speaker_id = list(speaker_manager.name_to_id.keys())[0]
             run_cli(
                 f"tts --model_name {model_name} "
-                f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" '
+                f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" --progress_bar False'
             )
         else:
             # single-speaker model
-            run_cli(f"tts --model_name {model_name} " f'--text "This is an example." --out_path "{output_path}"')
+            run_cli(
+                f"tts --model_name {model_name} "
+                f'--text "This is an example." --out_path "{output_path}" --progress_bar False'
+            )
         # remove downloaded models
         shutil.rmtree(download_dir)
     else:
@@ -67,5 +70,5 @@ def test_voice_conversion():
     output_path = os.path.join(get_tests_output_path(), "output.wav")
     run_cli(
         f"tts --model_name {model_name}"
-        f" --out_path {output_path} --speaker_wav {speaker_wav} --reference_wav {reference_wav} --language_idx {language_id} "
+        f" --out_path {output_path} --speaker_wav {speaker_wav} --reference_wav {reference_wav} --language_idx {language_id} --progress_bar False"
     )
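
For reference, `run_cli` in the test suite shells out and asserts success; roughly like the sketch below (the actual helper lives in the tests package and may differ):

import os

def run_cli(command):
    # Run the command through the shell; a non-zero exit fails the test.
    exit_status = os.system(command)
    assert exit_status == 0, f" [!] command failed: {command}"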