mirror of https://github.com/coqui-ai/TTS.git

Update BaseTTS

parent 330ee7d208
commit 7c2cb7cc30
@@ -1,4 +1,5 @@
 import os
+import random
 from typing import Dict, List, Tuple
 
 import torch
@@ -9,12 +10,12 @@ from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
 from TTS.model import BaseModel
+from TTS.tts.configs.shared_configs import CharactersConfig
 from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text import make_symbols
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.audio import AudioProcessor
 
 # pylint: skip-file
 
@@ -64,7 +65,7 @@ class BaseTTS(BaseModel):
         else:
             from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
 
-            config.characters = parse_symbols()
+            config.characters = CharactersConfig(**parse_symbols())
         model_characters = phonemes if config.use_phonemes else symbols
         num_chars = len(model_characters) + getattr(config, "add_blank", False)
         return model_characters, config, num_chars
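A minimal sketch of why the wrapping above matters, assuming parse_symbols() returns a plain dict of symbol groups (the keys and values below are illustrative assumptions, not copied from the library): CharactersConfig is a Coqpit config, so the parsed symbols gain attribute access and clean serialization instead of staying a raw dict.

# Sketch only: symbol values are made up; real ones come from parse_symbols().
from TTS.tts.configs.shared_configs import CharactersConfig

symbol_dict = {
    "pad": "_",
    "eos": "~",
    "bos": "^",
    "characters": "abcdefghijklmnopqrstuvwxyz",
    "punctuations": "!'(),-.:;? ",
    "phonemes": "ilu",
}
characters = CharactersConfig(**symbol_dict)  # dict -> typed Coqpit config
assert characters.pad == "_"  # attribute access instead of dict indexing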
@@ -80,14 +81,13 @@ class BaseTTS(BaseModel):
             config (Coqpit): Model configuration.
         """
         # init speaker manager
-        if self.speaker_manager is None:
-            raise ValueError(" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model.")
-        print(f" > Number of speakers : {len(self.speaker_manager.speaker_ids)}")
-        # set number of speakers - if num_speakers is set in config, use it, otherwise use speaker_manager
-        self.num_speakers = self.speaker_manager.num_speakers
+        if self.speaker_manager is None and (config.use_speaker_embedding or config.use_d_vector_file):
+            raise ValueError(
+                " > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
+            )
+        # set number of speakers
+        if self.speaker_manager is not None:
+            self.num_speakers = self.speaker_manager.num_speakers
 
         # set ultimate speaker embedding size
         if config.use_speaker_embedding or config.use_d_vector_file:
             self.embedded_speaker_dim = (
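A minimal sketch of the reworked guard, using a hypothetical stand-in for the Coqpit config that carries only the two flags the guard reads: a missing SpeakerManager is now an error only when the config actually enables multi-speaker training.

# Hypothetical stand-in; the real config is a Coqpit with many more fields.
class StubConfig:
    def __init__(self, use_speaker_embedding=False, use_d_vector_file=False):
        self.use_speaker_embedding = use_speaker_embedding
        self.use_d_vector_file = use_d_vector_file

def manager_missing(config, speaker_manager):
    # Mirrors the guard: single-speaker models may run without a manager.
    return speaker_manager is None and (config.use_speaker_embedding or config.use_d_vector_file)

assert manager_missing(StubConfig(), None) is False                   # single-speaker: OK
assert manager_missing(StubConfig(use_speaker_embedding=True), None)  # would raise ValueError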
@@ -99,10 +99,6 @@ class BaseTTS(BaseModel):
             self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
 
-    def get_aux_input(self, **kwargs) -> Dict:
-        """Prepare and return `aux_input` used by `forward()`"""
-        return {"speaker_id": None, "style_wav": None, "d_vector": None}
-
     def format_batch(self, batch: Dict) -> Dict:
         """Generic batch formatting for `TTSDataset`.
 
@@ -293,6 +289,20 @@ class BaseTTS(BaseModel):
         )
         return loader
 
+    def _get_test_aux_input(
+        self,
+    ) -> Dict:
+        aux_inputs = {
+            "speaker_id": None
+            if not self.config.use_speaker_embedding
+            else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
+            "d_vector": None
+            if not self.config.use_d_vector_file
+            else random.sample(sorted(self.speaker_manager.d_vectors.values()), 1),
+            "style_wav": None,  # TODO: handle GST style input
+        }
+        return aux_inputs
+
     def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
         """Generic test run for `tts` models used by `Trainer`.
 
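How the random draw in the new helper behaves, sketched with a toy speaker mapping (the names and ids are invented; real values come from SpeakerManager.speaker_ids): random.sample(population, 1) returns a one-element list, and sorting the values first makes the draw deterministic under a fixed seed regardless of dict ordering.

import random

random.seed(0)  # fix the seed so the example is repeatable
speaker_ids = {"ljspeech": 0, "vctk_p225": 1, "vctk_p226": 2}  # toy mapping

speaker_id = random.sample(sorted(speaker_ids.values()), 1)
print(speaker_id)  # a one-element list, e.g. [2]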
@@ -309,7 +319,7 @@ class BaseTTS(BaseModel):
         test_audios = {}
         test_figures = {}
         test_sentences = self.config.test_sentences
-        aux_inputs = self.get_aux_input()
+        aux_inputs = self._get_test_aux_input()
         for idx, sen in enumerate(test_sentences):
             outputs_dict = synthesis(
                 self,
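With this call-site change, one aux-input dict is drawn before the loop and shared by every test sentence. A runnable toy stand-in (an entirely hypothetical class, not the library's API) showing the dict shape the loop receives in the single-speaker case:

# Hypothetical stand-in model: both multi-speaker flags off, so every field is None.
class StubModel:
    class config:
        use_speaker_embedding = False
        use_d_vector_file = False

    def _get_test_aux_input(self):
        # Single-speaker fallback, matching the helper added above.
        return {"speaker_id": None, "d_vector": None, "style_wav": None}

aux_inputs = StubModel()._get_test_aux_input()
assert set(aux_inputs) == {"speaker_id", "d_vector", "style_wav"}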