Load fairseq models

This commit is contained in:
Eren Gölge 2023-05-23 13:11:04 +02:00
parent aef7f6d980
commit fb31ce4b0a
2 changed files with 67 additions and 6 deletions

View File

@ -25,11 +25,12 @@ from TTS.tts.layers.vits.discriminator import VitsDiscriminator
from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.fairseq import rehash_fairseq_vits_checkpoint
from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.io import load_fsspec
@ -1723,6 +1724,31 @@ class Vits(BaseTTS):
self.eval()
assert not self.training
def load_fairseq_checkpoint(self, config, checkpoint_dir, eval=False):  # pylint: disable=redefined-builtin
    """Load a Fairseq VITS checkpoint directory into this model.

    Reads ``config.json``, ``G_100000.pth`` and ``vocab.txt`` from ``checkpoint_dir``,
    replaces the text-encoder embedding and the tokenizer so they match the Fairseq
    vocabulary, and loads the state dict returned by ``rehash_fairseq_vits_checkpoint``.

    Args:
        config: Model config; ``config.model_args.hidden_channels`` sets the width of the
            new text-encoder embedding.
        checkpoint_dir (str): Directory containing the three files listed above.
        eval (bool): If True, put the model in evaluation mode after loading.

    Raises:
        KeyError: If ``config.json`` lacks ``data.sampling_rate`` or ``data.add_blank``.
    """
    import json

    # set paths
    config_file = os.path.join(checkpoint_dir, "config.json")
    checkpoint_file = os.path.join(checkpoint_dir, "G_100000.pth")
    vocab_file = os.path.join(checkpoint_dir, "vocab.txt")

    # set config params; read as UTF-8 explicitly so decoding does not depend on the locale
    with open(config_file, "r", encoding="utf-8") as file:
        config_org = json.load(file)
    self.config.audio.sample_rate = config_org["data"]["sampling_rate"]
    # NOTE: ``add_blank`` is only forwarded to the tokenizer below; it is not written to ``self.config``.

    # set tokenizer; the embedding is rebuilt because its row count must equal the new vocabulary size
    vocab = FairseqVocab(vocab_file)
    self.text_encoder.emb = nn.Embedding(vocab.num_chars, config.model_args.hidden_channels)
    self.tokenizer = TTSTokenizer(
        use_phonemes=False,
        text_cleaner=None,
        characters=vocab,
        phonemizer=None,
        add_blank=config_org["data"]["add_blank"],
        use_eos_bos=False,
    )

    # load fairseq checkpoint
    new_chk = rehash_fairseq_vits_checkpoint(checkpoint_file)
    self.load_state_dict(new_chk)

    if eval:
        self.eval()
        # only meaningful after eval(); guards against the mode switch silently failing
        assert not self.training
@staticmethod
def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
"""Initiate model from config
@ -1919,3 +1945,23 @@ class VitsCharacters(BaseCharacters):
is_unique=False,
is_sorted=True,
)
class FairseqVocab(BaseVocabulary):
    """Vocabulary built from a Fairseq ``vocab.txt`` file with one token per line.

    The token on the first line is used as the blank symbol and a single space is used
    as the pad symbol. Token ids are the zero-based line indices.
    """

    def __init__(self, vocab: str):
        """Build the vocabulary from a file.

        Args:
            vocab (str): Path to the vocabulary file.
        """
        # NOTE(review): one-argument ``super(FairseqVocab)`` creates an unbound super object, so
        # this call does not run the parent initializer. Left unchanged because the parent
        # ``__init__`` is not visible here -- confirm whether ``super().__init__(...)`` was intended.
        super(FairseqVocab).__init__()
        self.vocab = vocab

    @property
    def vocab(self):
        """Return the list of vocabulary tokens in file order."""
        return self._vocab

    @vocab.setter
    def vocab(self, vocab_file):
        """Read ``vocab_file`` and rebuild the token list and both lookup tables.

        Args:
            vocab_file (str): Path to a text file with one token per line.

        Raises:
            IndexError: If the file contains no lines (there is no first token to use as blank).
        """
        # context manager closes the handle; explicit UTF-8 keeps decoding locale-independent
        with open(vocab_file, encoding="utf-8") as vocab_fh:
            self._vocab = [line.replace("\n", "") for line in vocab_fh]
        self.blank = self._vocab[0]
        self.pad = " "
        self._char_to_id = {s: i for i, s in enumerate(self._vocab)}
        self._id_to_char = dict(enumerate(self._vocab))

View File

@ -71,11 +71,15 @@ class BaseVocabulary:
@vocab.setter
def vocab(self, vocab):
    """Set the vocabulary and rebuild the character/id mapping dictionaries.

    Args:
        vocab: Iterable of tokens, or ``None`` to clear the vocabulary and both mappings.
    """
    # Reset first so that ``None`` leaves all three attributes cleared instead of
    # reaching ``enumerate(None)`` and raising TypeError.
    self._vocab, self._char_to_id, self._id_to_char = None, None, None
    if vocab is not None:
        self._vocab = vocab
        self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
        self._id_to_char = dict(enumerate(self._vocab))
@staticmethod
def init_from_config(config, **kwargs):
@ -93,6 +97,17 @@ class BaseVocabulary:
)
return BaseVocabulary(**kwargs), config
def to_config(self) -> "CharactersConfig":
    """Export this vocabulary as a ``CharactersConfig``.

    The token collection is passed as ``vocab_dict`` together with the pad, eos, bos and
    blank symbols; ``is_unique`` and ``is_sorted`` are both set to False.
    """
    settings = {
        "vocab_dict": self._vocab,
        "pad": self.pad,
        "eos": self.eos,
        "bos": self.bos,
        "blank": self.blank,
        "is_unique": False,
        "is_sorted": False,
    }
    return CharactersConfig(**settings)
@property
def num_chars(self):
"""Return number of tokens in the vocabulary."""