mirror of https://github.com/coqui-ai/TTS.git
Load fairseq models
This commit is contained in:
parent aef7f6d980
commit fb31ce4b0a
@@ -25,11 +25,12 @@ from TTS.tts.layers.vits.discriminator import VitsDiscriminator
 from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, TextEncoder
 from TTS.tts.layers.vits.stochastic_duration_predictor import StochasticDurationPredictor
 from TTS.tts.models.base_tts import BaseTTS
+from TTS.tts.utils.fairseq import rehash_fairseq_vits_checkpoint
 from TTS.tts.utils.helpers import generate_path, maximum_path, rand_segments, segment, sequence_mask
 from TTS.tts.utils.languages import LanguageManager
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
-from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
+from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
 from TTS.utils.io import load_fsspec
@@ -1723,6 +1724,31 @@ class Vits(BaseTTS):
             self.eval()
             assert not self.training
 
+    def load_fairseq_checkpoint(self, config, checkpoint_dir, eval=False):
+        import json
+        # set paths
+        config_file = os.path.join(checkpoint_dir, "config.json")
+        checkpoint_file = os.path.join(checkpoint_dir, "G_100000.pth")
+        vocab_file = os.path.join(checkpoint_dir, "vocab.txt")
+        # set config params
+        with open(config_file, 'r') as file:
+            # Load the JSON data as a dictionary
+            config_org = json.load(file)
+        self.config.audio.sample_rate = config_org['data']['sampling_rate']
+        # self.config.add_blank = config['add_blank']
+        # set tokenizer
+        vocab = FairseqVocab(vocab_file)
+        self.text_encoder.emb = nn.Embedding(vocab.num_chars, config.model_args.hidden_channels)
+        self.tokenizer = TTSTokenizer(
+            use_phonemes=False, text_cleaner=None, characters=vocab, phonemizer=None, add_blank=config_org['data']['add_blank'], use_eos_bos=False
+        )
+        # load fairseq checkpoint
+        new_chk = rehash_fairseq_vits_checkpoint(checkpoint_file)
+        self.load_state_dict(new_chk)
+        if eval:
+            self.eval()
+            assert not self.training
+
     @staticmethod
     def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]] = None, verbose=True):
         """Initiate model from config
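For anyone trying the new entry point, a minimal usage sketch follows. The checkpoint directory layout (config.json, G_100000.pth, vocab.txt) is the one load_fairseq_checkpoint hard-codes above; the path below is invented, and the surrounding setup (VitsConfig, Vits.init_from_config) is the existing Coqui API, so treat this as a sketch rather than a verified conversion recipe:

```python
# Hedged sketch: load a fairseq VITS checkpoint into a Coqui Vits model.
# "/data/fairseq/eng" is a hypothetical path; the directory must contain
# the config.json, G_100000.pth, and vocab.txt files the new method expects.
from TTS.tts.configs.vits_config import VitsConfig
from TTS.tts.models.vits import Vits

config = VitsConfig()
model = Vits.init_from_config(config)
model.load_fairseq_checkpoint(config, "/data/fairseq/eng", eval=True)
```

Note that the method mutates the model in place: it resizes the text-encoder embedding to the fairseq vocabulary, swaps in a fresh TTSTokenizer, and only then loads the rehashed state dict.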
@@ -1919,3 +1945,23 @@ class VitsCharacters(BaseCharacters):
             is_unique=False,
             is_sorted=True,
         )
+
+
+class FairseqVocab(BaseVocabulary):
+    def __init__(self, vocab: str):
+        super(FairseqVocab).__init__()
+        self.vocab = vocab
+
+    @property
+    def vocab(self):
+        """Return the vocabulary as a list of tokens."""
+        return self._vocab
+
+    @vocab.setter
+    def vocab(self, vocab_file):
+        self._vocab = [x.replace("\n", "") for x in open(vocab_file).readlines()]
+        self.blank = self._vocab[0]
+        print(self._vocab)
+        self.pad = " "
+        self._char_to_id = {s: i for i, s in enumerate(self._vocab)}
+        self._id_to_char = {i: s for i, s in enumerate(self._vocab)}
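The vocab setter above fixes the file contract: one token per line, the first line becomes the blank token, and a literal space serves as pad. Note also that super(FairseqVocab).__init__() builds an unbound super object, so it is effectively a no-op; replacing it with super().__init__() would route vocab=None through this overriding setter and crash in open(). A small illustration with an invented token list (real fairseq vocab files are far larger):

```python
# Illustrative only: exercise FairseqVocab on a tiny, made-up vocab file.
# num_chars and char_to_id are helpers inherited from BaseVocabulary.
import tempfile

from TTS.tts.models.vits import FairseqVocab

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("\n".join(["_", " ", "a", "b", "c"]))  # "_" on line 1 becomes blank

vocab = FairseqVocab(f.name)
print(vocab.num_chars)        # 5
print(vocab.char_to_id("a"))  # 2, ids follow file order
```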
@@ -71,11 +71,15 @@ class BaseVocabulary:
     @vocab.setter
     def vocab(self, vocab):
         """Set the vocabulary dictionary and character mapping dictionaries."""
-        self._vocab = vocab
-        self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
-        self._id_to_char = {
-            idx: char for idx, char in enumerate(self._vocab)  # pylint: disable=unnecessary-comprehension
-        }
+        self._vocab, self._char_to_id, self._id_to_char = None, None, None
+        if vocab is not None:
+            self._vocab = vocab
+            self._char_to_id = {char: idx for idx, char in enumerate(self._vocab)}
+            self._id_to_char = {
+                idx: char for idx, char in enumerate(self._vocab)  # pylint: disable=unnecessary-comprehension
+            }
 
     @staticmethod
     def init_from_config(config, **kwargs):
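The guard above makes the setter tolerate vocab=None, which is what BaseVocabulary's constructor passes by default; previously enumerate(None) raised a TypeError before any mapping could be built. A quick illustration (assuming, as the property pattern suggests, that the constructor routes its vocab argument through this setter):

```python
# With the None guard, a BaseVocabulary can be created empty and filled later.
from TTS.tts.utils.text.characters import BaseVocabulary

v = BaseVocabulary(vocab=None)  # previously: TypeError from enumerate(None)
assert v.vocab is None

v.vocab = {"a": 0, "b": 1}      # assigning a real vocab builds both id maps
assert v.char_to_id("a") == 0
assert v.id_to_char(1) == "b"
```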
@@ -93,6 +97,17 @@ class BaseVocabulary:
         )
         return BaseVocabulary(**kwargs), config
 
+    def to_config(self) -> "CharactersConfig":
+        return CharactersConfig(
+            vocab_dict=self._vocab,
+            pad=self.pad,
+            eos=self.eos,
+            bos=self.bos,
+            blank=self.blank,
+            is_unique=False,
+            is_sorted=False,
+        )
+
     @property
     def num_chars(self):
         """Return number of tokens in the vocabulary."""
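The new to_config completes the round trip: a vocabulary built from a raw fairseq vocab.txt can be serialized back into a CharactersConfig, for example when writing out the config of a converted checkpoint. A hedged sketch with invented values (eos and bos stay None unless set on the instance):

```python
# Sketch: serialize a BaseVocabulary back into a CharactersConfig.
from TTS.tts.utils.text.characters import BaseVocabulary

v = BaseVocabulary(vocab={"_": 0, " ": 1, "a": 2}, pad=" ", blank="_")
chars_config = v.to_config()
print(chars_config.vocab_dict)  # {'_': 0, ' ': 1, 'a': 2}
print(chars_config.blank)       # "_"
```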