Update BaseTTS

Eren Gölge 2021-10-20 18:18:22 +00:00
parent 330ee7d208
commit 7c2cb7cc30
1 changed file with 25 additions and 15 deletions

@@ -1,4 +1,5 @@
import os
import random
from typing import Dict, List, Tuple
import torch
@@ -9,12 +10,12 @@ from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.model import BaseModel
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text import make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
# pylint: skip-file
@@ -64,7 +65,7 @@ class BaseTTS(BaseModel):
else:
from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
config.characters = parse_symbols()
config.characters = CharactersConfig(**parse_symbols())
model_characters = phonemes if config.use_phonemes else symbols
num_chars = len(model_characters) + getattr(config, "add_blank", False)
return model_characters, config, num_chars
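The functional change in this hunk is that the parsed symbol set is now wrapped in a `CharactersConfig` instead of being stored as a raw dict. A minimal sketch of the resulting flow, using a `SimpleNamespace` stand-in for the model config and assuming `parse_symbols()` returns a dict whose keys match the `CharactersConfig` fields (the change itself relies on that):

from types import SimpleNamespace
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols

config = SimpleNamespace(use_phonemes=True, add_blank=True, characters=None)
config.characters = CharactersConfig(**parse_symbols())   # structured config object instead of a raw dict
model_characters = phonemes if config.use_phonemes else symbols
num_chars = len(model_characters) + getattr(config, "add_blank", False)   # bool True counts as 1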
@@ -80,14 +81,13 @@ class BaseTTS(BaseModel):
config (Coqpit): Model configuration.
"""
# init speaker manager
if self.speaker_manager is None:
raise ValueError(" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model.")
print(f" > Number of speakers : {len(self.speaker_manager.speaker_ids)}")
# set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
self.num_speakers = self.speaker_manager.num_speakers
if self.speaker_manager is None and (config.use_speaker_embedding or config.use_d_vector_file):
raise ValueError(
" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
)
# set number of speakers
if self.speaker_manager is not None:
self.num_speakers = self.speaker_manager.num_speakers
# set ultimate speaker embedding size
if config.use_speaker_embedding or config.use_d_vector_file:
self.embedded_speaker_dim = (
@@ -99,10 +99,6 @@ class BaseTTS(BaseModel):
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
def get_aux_input(self, **kwargs) -> Dict:
"""Prepare and return `aux_input` used by `forward()`"""
return {"speaker_id": None, "style_wav": None, "d_vector": None}
def format_batch(self, batch: Dict) -> Dict:
"""Generic batch formatting for `TTSDataset`.
@@ -293,6 +289,20 @@ class BaseTTS(BaseModel):
)
return loader
def _get_test_aux_input(
self,
) -> Dict:
aux_inputs = {
"speaker_id": None
if not self.config.use_speaker_embedding
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
"d_vector": None
if not self.config.use_d_vector_file
else random.sample(sorted(self.speaker_manager.d_vectors.values()), 1),
"style_wav": None, # TODO: handle GST style input
}
return aux_inputs
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`.
@@ -309,7 +319,7 @@ class BaseTTS(BaseModel):
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self.get_aux_input()
aux_inputs = self._get_test_aux_input()
for idx, sen in enumerate(test_sentences):
outputs_dict = synthesis(
self,