This commit is contained in:
Eren G??lge 2023-05-24 11:55:48 +02:00
parent a8ce0144c2
commit d39878eac0
2 changed files with 7 additions and 3 deletions

View File

@ -1724,7 +1724,9 @@ class Vits(BaseTTS):
self.eval()
assert not self.training
def load_fairseq_checkpoint(self, config, checkpoint_dir, eval=False): # pylint: disable=unused-argument, redefined-builtin
def load_fairseq_checkpoint(
self, config, checkpoint_dir, eval=False
): # pylint: disable=unused-argument, redefined-builtin
"""Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms
Performs some changes for compatibility.
@ -1735,6 +1737,8 @@ class Vits(BaseTTS):
"""
import json
from TTS.tts.utils.text.cleaners import basic_cleaners
self.disc = None
# set paths
config_file = os.path.join(checkpoint_dir, "config.json")
@ -1751,7 +1755,7 @@ class Vits(BaseTTS):
self.text_encoder.emb = nn.Embedding(vocab.num_chars, config.model_args.hidden_channels)
self.tokenizer = TTSTokenizer(
use_phonemes=False,
text_cleaner=None,
text_cleaner=basic_cleaners,
characters=vocab,
phonemizer=None,
add_blank=config_org["data"]["add_blank"],

View File

@ -2,7 +2,7 @@ import torch
def rehash_fairseq_vits_checkpoint(checkpoint_file):
chk = torch.load(checkpoint_file, map_location=torch.device('cpu'))["model"]
chk = torch.load(checkpoint_file, map_location=torch.device("cpu"))["model"]
new_chk = {}
for k, v in chk.items():
if "enc_p." in k: