mirror of https://github.com/coqui-ai/TTS.git
Fixup
This commit is contained in:
parent
a8ce0144c2
commit
d39878eac0
|
@ -1724,7 +1724,9 @@ class Vits(BaseTTS):
|
||||||
self.eval()
|
self.eval()
|
||||||
assert not self.training
|
assert not self.training
|
||||||
|
|
||||||
def load_fairseq_checkpoint(self, config, checkpoint_dir, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
def load_fairseq_checkpoint(
|
||||||
|
self, config, checkpoint_dir, eval=False
|
||||||
|
): # pylint: disable=unused-argument, redefined-builtin
|
||||||
"""Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms
|
"""Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms
|
||||||
Performs some changes for compatibility.
|
Performs some changes for compatibility.
|
||||||
|
|
||||||
|
@ -1735,6 +1737,8 @@ class Vits(BaseTTS):
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from TTS.tts.utils.text.cleaners import basic_cleaners
|
||||||
|
|
||||||
self.disc = None
|
self.disc = None
|
||||||
# set paths
|
# set paths
|
||||||
config_file = os.path.join(checkpoint_dir, "config.json")
|
config_file = os.path.join(checkpoint_dir, "config.json")
|
||||||
|
@ -1751,7 +1755,7 @@ class Vits(BaseTTS):
|
||||||
self.text_encoder.emb = nn.Embedding(vocab.num_chars, config.model_args.hidden_channels)
|
self.text_encoder.emb = nn.Embedding(vocab.num_chars, config.model_args.hidden_channels)
|
||||||
self.tokenizer = TTSTokenizer(
|
self.tokenizer = TTSTokenizer(
|
||||||
use_phonemes=False,
|
use_phonemes=False,
|
||||||
text_cleaner=None,
|
text_cleaner=basic_cleaners,
|
||||||
characters=vocab,
|
characters=vocab,
|
||||||
phonemizer=None,
|
phonemizer=None,
|
||||||
add_blank=config_org["data"]["add_blank"],
|
add_blank=config_org["data"]["add_blank"],
|
||||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
||||||
|
|
||||||
|
|
||||||
def rehash_fairseq_vits_checkpoint(checkpoint_file):
|
def rehash_fairseq_vits_checkpoint(checkpoint_file):
|
||||||
chk = torch.load(checkpoint_file, map_location=torch.device('cpu'))["model"]
|
chk = torch.load(checkpoint_file, map_location=torch.device("cpu"))["model"]
|
||||||
new_chk = {}
|
new_chk = {}
|
||||||
for k, v in chk.items():
|
for k, v in chk.items():
|
||||||
if "enc_p." in k:
|
if "enc_p." in k:
|
||||||
|
|
Loading…
Reference in New Issue