mirror of https://github.com/coqui-ai/TTS.git
Update BaseTTS
This commit is contained in:
parent
330ee7d208
commit
7c2cb7cc30
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
import random
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
@ -9,12 +10,12 @@ from torch.utils.data import DataLoader
|
|||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from TTS.model import BaseModel
|
||||
from TTS.tts.configs.shared_configs import CharactersConfig
|
||||
from TTS.tts.datasets.dataset import TTSDataset
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.text import make_symbols
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
|
@ -64,7 +65,7 @@ class BaseTTS(BaseModel):
|
|||
else:
|
||||
from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
|
||||
|
||||
config.characters = parse_symbols()
|
||||
config.characters = CharactersConfig(**parse_symbols())
|
||||
model_characters = phonemes if config.use_phonemes else symbols
|
||||
num_chars = len(model_characters) + getattr(config, "add_blank", False)
|
||||
return model_characters, config, num_chars
|
||||
|
@ -80,14 +81,13 @@ class BaseTTS(BaseModel):
|
|||
config (Coqpit): Model configuration.
|
||||
"""
|
||||
# init speaker manager
|
||||
if self.speaker_manager is None:
|
||||
raise ValueError(" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model.")
|
||||
|
||||
print(f" > Number of speakers : {len(self.speaker_manager.speaker_ids)}")
|
||||
|
||||
# set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
|
||||
if self.speaker_manager is None and (config.use_speaker_embedding or config.use_d_vector_file):
|
||||
raise ValueError(
|
||||
" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
|
||||
)
|
||||
# set number of speakers
|
||||
if self.speaker_manager is not None:
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
# set ultimate speaker embedding size
|
||||
if config.use_speaker_embedding or config.use_d_vector_file:
|
||||
self.embedded_speaker_dim = (
|
||||
|
@ -99,10 +99,6 @@ class BaseTTS(BaseModel):
|
|||
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
|
||||
def get_aux_input(self, **kwargs) -> Dict:
|
||||
"""Prepare and return `aux_input` used by `forward()`"""
|
||||
return {"speaker_id": None, "style_wav": None, "d_vector": None}
|
||||
|
||||
def format_batch(self, batch: Dict) -> Dict:
|
||||
"""Generic batch formatting for `TTSDataset`.
|
||||
|
||||
|
@ -293,6 +289,20 @@ class BaseTTS(BaseModel):
|
|||
)
|
||||
return loader
|
||||
|
||||
def _get_test_aux_input(
|
||||
self,
|
||||
) -> Dict:
|
||||
aux_inputs = {
|
||||
"speaker_id": None
|
||||
if not self.config.use_speaker_embedding
|
||||
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
|
||||
"d_vector": None
|
||||
if not self.config.use_d_vector_file
|
||||
else random.samples(sorted(self.speaker_manager.d_vectors.values()), 1),
|
||||
"style_wav": None, # TODO: handle GST style input
|
||||
}
|
||||
return aux_inputs
|
||||
|
||||
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
|
||||
"""Generic test run for `tts` models used by `Trainer`.
|
||||
|
||||
|
@ -309,7 +319,7 @@ class BaseTTS(BaseModel):
|
|||
test_audios = {}
|
||||
test_figures = {}
|
||||
test_sentences = self.config.test_sentences
|
||||
aux_inputs = self.get_aux_input()
|
||||
aux_inputs = self._get_test_aux_input()
|
||||
for idx, sen in enumerate(test_sentences):
|
||||
outputs_dict = synthesis(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue