diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 9f54cb39..112ef046 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -16,7 +16,6 @@ from TTS.tts.models import setup_model from TTS.tts.utils.speakers import get_speaker_manager from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import count_parameters -from TTS.utils.io import load_fsspec use_cuda = torch.cuda.is_available() @@ -239,8 +238,7 @@ def main(args): # pylint: disable=redefined-outer-name model = setup_model(c) # restore model - checkpoint = load_fsspec(args.checkpoint_path, map_location="cpu") - model.load_state_dict(checkpoint["model"]) + model.load_checkpoint(c, args.checkpoint_path, eval=True) if use_cuda: model.cuda()