From fb31ce4b0a7774b74babcacd501f9841f5490e12 Mon Sep 17 00:00:00 2001
From: Eren Gölge
Date: Tue, 23 May 2023 13:11:04 +0200
Subject: [PATCH] Load fairseq models

---
 TTS/tts/models/vits.py           | 48 +++++++++++++++++++++++++++++++-
 TTS/tts/utils/text/characters.py | 25 +++++++++++++----
 2 files changed, 67 insertions(+), 6 deletions(-)

diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 2e0c32c8..881fcb33 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -25,11 +25,12 @@ from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
+from TTS.tts.utils.fairseq import rehash_fairseq_vits_checkpoint
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
-from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
+from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
 from TTS.utils.io import load_fsspec
@@ -1723,6 +1724,31 @@ class Vits(BaseTTS):
         self.eval()
         assert not self.training
 
+    def load_fairseq_checkpoint(self, config, checkpoint_dir, eval=False):
+        import json
+        # set paths
+        config_file = os.path.join(checkpoint_dir, "config.json")
+        checkpoint_file = os.path.join(checkpoint_dir, "G_100000.pth")
+        vocab_file = os.path.join(checkpoint_dir, "vocab.txt")
+        # set config params
+        with open(config_file, "r") as file:
+            # Load the JSON data as a dictionary
+            config_org = json.load(file)
+        self.config.audio.sample_rate = config_org["data"]["sampling_rate"]
+        # self.config.add_blank = config['add_blank']
+        # set tokenizer
+        vocab = FairseqVocab(vocab_file)
+        self.text_encoder.emb = nn.Embedding(vocab.num_chars, config.model_args.hidden_channels)
+        self.tokenizer = TTSTokenizer(
+            use_phonemes=False, text_cleaner=None, characters=vocab, phonemizer=None, add_blank=config_org["data"]["add_blank"], use_eos_bos=False
+        )
+        # load fairseq checkpoint
+        new_chk = rehash_fairseq_vits_checkpoint(checkpoint_file)
+        self.load_state_dict(new_chk)
+        if eval:
+            self.eval()
+            assert not self.training
+
     @staticmethod
     def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
         """Initiate model from config
@@ -1919,3 +1945,23 @@ class VitsCharacters(BaseCharacters):
             is_unique=False,
             is_sorted=True,
         )
+
+
+class FairseqVocab(BaseVocabulary):
+    def __init__(self, vocab: str):
+        super().__init__()
+        self.vocab = vocab
+
+    @property
+    def vocab(self):
+        """Return the vocabulary dictionary."""
+        return self._vocab
+
+    @vocab.setter
+    def vocab(self, vocab_file):
+        self._vocab = [x.replace("\n", "") for x in open(vocab_file, encoding="utf-8").readlines()]
+        self.blank = self._vocab[0]
+        print(self._vocab)
+        self.pad = " "
+        self._char_to_id = {s: i for i, s in enumerate(self._vocab)}
+        self._id_to_char = {i: s for i, s in enumerate(self._vocab)}
diff --git a/TTS/tts/utils/text/characters.py b/TTS/tts/utils/text/characters.py
index 1b375e4f..2477aefc 100644
--- a/TTS/tts/utils/text/characters.py
+++ b/TTS/tts/utils/text/characters.py
@@ -71,11 +71,15 @@ class BaseVocabulary:
     @vocab.setter
     def vocab(self, vocab):
         """Set the vocabulary dictionary and character mapping dictionaries."""
-        self._vocab = vocab
-        self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
-        self._id_to_char = {
-            idx: char for idx, char in enumerate(self._vocab)  # pylint: disable=unnecessary-comprehension
-        }
+        self._vocab, self._char_to_id, self._id_to_char = None, None, None
+        if vocab is not None:
+            self._vocab = vocab
+            self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
+            self._id_to_char = {
+                idx: char for idx, char in enumerate(self._vocab)  # pylint: disable=unnecessary-comprehension
+            }
+
+
     @staticmethod
     def init_from_config(config, **kwargs):
@@ -93,6 +97,17 @@ class BaseVocabulary:
         )
         return BaseVocabulary(**kwargs), config
 
+    def to_config(self) -> "CharactersConfig":
+        return CharactersConfig(
+            vocab_dict=self._vocab,
+            pad=self.pad,
+            eos=self.eos,
+            bos=self.bos,
+            blank=self.blank,
+            is_unique=False,
+            is_sorted=False,
+        )
+
     @property
     def num_chars(self):
         """Return number of tokens in the vocabulary."""
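-- 
Usage sketch (below the signature delimiter, so it is not part of the
patch): a minimal sketch of how the new loader might be invoked,
assuming a fairseq MMS VITS checkpoint has already been downloaded
locally. The directory layout ("config.json", "G_100000.pth",
"vocab.txt") is the one hard-coded in load_fairseq_checkpoint; the
checkpoint path below is a placeholder, not a real release path.

    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.models.vits import Vits

    # Build a default VITS model, then overwrite its weights, tokenizer,
    # and sample rate from the fairseq checkpoint directory.
    config = VitsConfig()
    model = Vits.init_from_config(config)
    model.load_fairseq_checkpoint(config, "/path/to/mms/vits/eng", eval=True)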