Update BaseTTS

This commit is contained in:
Eren Gölge 2021-10-20 18:18:22 +00:00
parent 330ee7d208
commit 7c2cb7cc30
1 changed file with 25 additions and 15 deletions

View File

@ -1,4 +1,5 @@
import os import os
import random
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import torch import torch
@ -9,12 +10,12 @@ from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel from TTS.model import BaseModel
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text import make_symbols from TTS.tts.utils.text import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
# pylint: skip-file # pylint: skip-file
@ -64,7 +65,7 @@ class BaseTTS(BaseModel):
else: else:
from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
config.characters = parse_symbols() config.characters = CharactersConfig(**parse_symbols())
model_characters = phonemes if config.use_phonemes else symbols model_characters = phonemes if config.use_phonemes else symbols
num_chars = len(model_characters) + getattr(config, "add_blank", False) num_chars = len(model_characters) + getattr(config, "add_blank", False)
return model_characters, config, num_chars return model_characters, config, num_chars
@ -80,14 +81,13 @@ class BaseTTS(BaseModel):
config (Coqpit): Model configuration. config (Coqpit): Model configuration.
""" """
# init speaker manager # init speaker manager
if self.speaker_manager is None: if self.speaker_manager is None and (config.use_speaker_embedding or config.use_d_vector_file):
raise ValueError(" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model.") raise ValueError(
" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
print(f" > Number of speakers : {len(self.speaker_manager.speaker_ids)}") )
# set number of speakers
# set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager if self.speaker_manager is not None:
self.num_speakers = self.speaker_manager.num_speakers self.num_speakers = self.speaker_manager.num_speakers
# set ultimate speaker embedding size # set ultimate speaker embedding size
if config.use_speaker_embedding or config.use_d_vector_file: if config.use_speaker_embedding or config.use_d_vector_file:
self.embedded_speaker_dim = ( self.embedded_speaker_dim = (
@ -99,10 +99,6 @@ class BaseTTS(BaseModel):
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3) self.speaker_embedding.weight.data.normal_(0, 0.3)
def get_aux_input(self, **kwargs) -> Dict:
    """Return the default `aux_input` dictionary consumed by `forward()`.

    Keyword arguments are accepted but ignored. Every entry of the returned
    dictionary is `None`.

    Returns:
        Dict: A new dictionary with the keys `speaker_id`, `style_wav` and
            `d_vector`, in that order.
    """
    aux_keys = ("speaker_id", "style_wav", "d_vector")
    # `dict.fromkeys` fills every key with `None` and builds a fresh dict on each call.
    return dict.fromkeys(aux_keys)
def format_batch(self, batch: Dict) -> Dict: def format_batch(self, batch: Dict) -> Dict:
"""Generic batch formatting for `TTSDataset`. """Generic batch formatting for `TTSDataset`.
@ -293,6 +289,20 @@ class BaseTTS(BaseModel):
) )
return loader return loader
def _get_test_aux_input(
    self,
) -> Dict:
    """Pick the auxiliary inputs used to synthesize the test sentences.

    A speaker id is sampled only when `config.use_speaker_embedding` is set,
    and a d-vector only when `config.use_d_vector_file` is set. Candidates
    are sorted before sampling, so the draw depends only on the state of the
    `random` module and not on dictionary ordering.

    Returns:
        Dict: A dictionary with the keys `speaker_id`, `d_vector` and
            `style_wav`. A sampled entry is the one-element list returned by
            `random.sample(..., 1)`; a disabled entry is `None`. `style_wav`
            is always `None`.
    """
    speaker_id = None
    if self.config.use_speaker_embedding:
        speaker_id = random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1)

    d_vector = None
    if self.config.use_d_vector_file:
        # Was `random.samples`, which does not exist and raised AttributeError.
        # NOTE(review): `sorted` needs the d-vector values to be mutually
        # orderable (e.g. lists of floats) - confirm against SpeakerManager.
        d_vector = random.sample(sorted(self.speaker_manager.d_vectors.values()), 1)

    aux_inputs = {
        "speaker_id": speaker_id,
        "d_vector": d_vector,
        "style_wav": None,  # TODO: handle GST style input
    }
    return aux_inputs
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]: def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`. """Generic test run for `tts` models used by `Trainer`.
@ -309,7 +319,7 @@ class BaseTTS(BaseModel):
test_audios = {} test_audios = {}
test_figures = {} test_figures = {}
test_sentences = self.config.test_sentences test_sentences = self.config.test_sentences
aux_inputs = self.get_aux_input() aux_inputs = self._get_test_aux_input()
for idx, sen in enumerate(test_sentences): for idx, sen in enumerate(test_sentences):
outputs_dict = synthesis( outputs_dict = synthesis(
self, self,