diff --git a/TTS/tts/utils/fairseq.py b/TTS/tts/utils/fairseq.py index 99c081ae..b1012958 100644 --- a/TTS/tts/utils/fairseq.py +++ b/TTS/tts/utils/fairseq.py @@ -2,7 +2,7 @@ import torch def rehash_fairseq_vits_checkpoint(checkpoint_file): - chk = torch.load(checkpoint_file)["model"] + chk = torch.load(checkpoint_file, map_location=torch.device('cpu'))["model"] new_chk = {} for k, v in chk.items(): if "enc_p." in k: