This commit is contained in:
Eren G??lge 2023-05-24 11:55:48 +02:00
parent a8ce0144c2
commit d39878eac0
2 changed files with 7 additions and 3 deletions

View File

@ -1724,7 +1724,9 @@ class Vits(BaseTTS):
self.eval() self.eval()
assert not self.training assert not self.training
def load_fairseq_checkpoint(self, config, checkpoint_dir, eval=False): # pylint: disable=unused-argument, redefined-builtin def load_fairseq_checkpoint(
self, config, checkpoint_dir, eval=False
): # pylint: disable=unused-argument, redefined-builtin
"""Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms """Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms
Performs some changes for compatibility. Performs some changes for compatibility.
@ -1735,6 +1737,8 @@ class Vits(BaseTTS):
""" """
import json import json
from TTS.tts.utils.text.cleaners import basic_cleaners
self.disc = None self.disc = None
# set paths # set paths
config_file = os.path.join(checkpoint_dir, "config.json") config_file = os.path.join(checkpoint_dir, "config.json")
@ -1751,7 +1755,7 @@ class Vits(BaseTTS):
self.text_encoder.emb = nn.Embedding(vocab.num_chars, config.model_args.hidden_channels) self.text_encoder.emb = nn.Embedding(vocab.num_chars, config.model_args.hidden_channels)
self.tokenizer = TTSTokenizer( self.tokenizer = TTSTokenizer(
use_phonemes=False, use_phonemes=False,
text_cleaner=None, text_cleaner=basic_cleaners,
characters=vocab, characters=vocab,
phonemizer=None, phonemizer=None,
add_blank=config_org["data"]["add_blank"], add_blank=config_org["data"]["add_blank"],

View File

@ -2,7 +2,7 @@ import torch
def rehash_fairseq_vits_checkpoint(checkpoint_file): def rehash_fairseq_vits_checkpoint(checkpoint_file):
chk = torch.load(checkpoint_file, map_location=torch.device('cpu'))["model"] chk = torch.load(checkpoint_file, map_location=torch.device("cpu"))["model"]
new_chk = {} new_chk = {}
for k, v in chk.items(): for k, v in chk.items():
if "enc_p." in k: if "enc_p." in k: