Implement BaseTTS

Eren Gölge 2022-02-20 11:50:13 +01:00
parent 2bad098625
commit 35fc7270ff
2 changed files with 11 additions and 12 deletions

@@ -9,7 +9,7 @@ from torch import nn
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
-from TTS.model import BaseModel
+from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
@@ -19,27 +19,22 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 # pylint: skip-file
 
 
-class BaseTTS(BaseModel):
+class BaseTTS(BaseTrainerModel):
     """Base `tts` class. Every new `tts` model must inherit this.
 
     It defines common `tts` specific functions on top of `Model` implementation.
-
-    Notes on input/output tensor shapes:
-        Any input or output tensor of the model must be shaped as
-
-        - 3D tensors `batch x time x channels`
-        - 2D tensors `batch x channels`
-        - 1D tensors `batch x 1`
     """
 
     def __init__(
-        self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None
+        self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None,
+        language_manager: LanguageManager = None
     ):
-        super().__init__(config)
+        super().__init__()
         self.config = config
         self.ap = ap
         self.tokenizer = tokenizer
         self.speaker_manager = speaker_manager
+        self.language_manager = language_manager
         self._set_model_args(config)
 
     def _set_model_args(self, config: Coqpit):
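
With this hunk, `BaseTTS` subclasses `BaseTrainerModel`, calls `super().__init__()` without the config, and accepts an optional `language_manager` next to `speaker_manager`, storing both on the instance. A minimal sketch of how a concrete model would forward the new argument; the `MyTTS` class and the import path for `BaseTTS` are assumptions for illustration, not part of this commit:

```python
from coqpit import Coqpit
from TTS.tts.models.base_tts import BaseTTS  # assumed import path
from TTS.tts.utils.languages import LanguageManager
from TTS.tts.utils.speakers import SpeakerManager


class MyTTS(BaseTTS):
    """Hypothetical subclass; only the constructor wiring is shown."""

    def __init__(
        self,
        config: Coqpit,
        ap: "AudioProcessor",
        tokenizer: "TTSTokenizer",
        speaker_manager: SpeakerManager = None,
        language_manager: LanguageManager = None,
    ):
        # Forward both managers; BaseTTS stores them as self.speaker_manager /
        # self.language_manager and calls its own super().__init__() with no arguments.
        super().__init__(config, ap, tokenizer, speaker_manager, language_manager)
```
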
@@ -262,7 +257,7 @@ class BaseTTS(BaseModel):
                d_vector_mapping = None
 
            # setup multi-lingual attributes
-            if hasattr(self, "language_manager"):
+            if hasattr(self, "language_manager") and self.language_manager is not None:
                language_id_mapping = (
                    self.language_manager.language_id_mapping if self.args.use_language_embedding else None
                )
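
The stricter condition matters because the constructor above now always defines `self.language_manager`, defaulting to `None` for single-language models, so `hasattr` alone is no longer a reliable signal. A small standalone snippet (not library code) showing the failure the added `is not None` check avoids:

```python
class _Model:
    def __init__(self, language_manager=None):
        # After this commit the attribute always exists, even when unset.
        self.language_manager = language_manager


m = _Model()

# Old guard: passes because the attribute exists...
if hasattr(m, "language_manager"):
    try:
        _ = m.language_manager.language_id_mapping  # ...then blows up on None
    except AttributeError as err:
        print("old guard fails:", err)

# New guard: skips the block entirely for single-language models.
if hasattr(m, "language_manager") and m.language_manager is not None:
    _ = m.language_manager.language_id_mapping
else:
    print("new guard: no language manager, nothing to do")
```
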

@@ -119,6 +119,10 @@ class TTSTokenizer:
         return [self.characters.bos] + list(char_sequence) + [self.characters.eos]
 
     def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False):
+        """Intersperses the blank character between characters in a sequence.
+
+        Use the ```blank``` character if defined else use the ```pad``` character.
+        """
         char_to_use = self.characters.blank if use_blank_char else self.characters.pad
         result = [char_to_use] * (len(char_sequence) * 2 + 1)
         result[1::2] = char_sequence
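
The added docstring matches what the slicing below does: build a list of `2 * len + 1` separator tokens and write the original characters into the odd indices, so every character ends up flanked by the `blank` (or `pad`) character. A standalone sketch of the same logic with a worked example; the `<BLNK>` token string is made up here, while the real method picks `self.characters.blank` or `self.characters.pad`:

```python
from typing import List


def intersperse_blank_char(char_sequence: List[str], char_to_use: str = "<BLNK>") -> List[str]:
    # Allocate a list twice as long plus one, filled with the separator token,
    # then drop the original characters into the odd positions.
    result = [char_to_use] * (len(char_sequence) * 2 + 1)
    result[1::2] = char_sequence
    return result


print(intersperse_blank_char(list("cat")))
# ['<BLNK>', 'c', '<BLNK>', 'a', '<BLNK>', 't', '<BLNK>']
```
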