Implement BaseTTS

Eren Gölge 2022-02-20 11:50:13 +01:00
parent 2bad098625
commit 35fc7270ff
2 changed files with 11 additions and 12 deletions


@@ -9,7 +9,7 @@ from torch import nn
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
-from TTS.model import BaseModel
+from TTS.model import BaseTrainerModel
 from TTS.tts.datasets.dataset import TTSDataset
 from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
 from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
@@ -19,27 +19,22 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 # pylint: skip-file
-class BaseTTS(BaseModel):
+class BaseTTS(BaseTrainerModel):
 """Base `tts` class. Every new `tts` model must inherit this.
 It defines common `tts` specific functions on top of `Model` implementation.
-Notes on input/output tensor shapes:
-Any input or output tensor of the model must be shaped as
-- 3D tensors `batch x time x channels`
-- 2D tensors `batch x channels`
-- 1D tensors `batch x 1`
 """
 def __init__(
-self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None
+self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None,
+language_manager: LanguageManager = None
 ):
-super().__init__(config)
+super().__init__()
+self.config = config
 self.ap = ap
 self.tokenizer = tokenizer
 self.speaker_manager = speaker_manager
+self.language_manager = language_manager
 self._set_model_args(config)
def _set_model_args(self, config: Coqpit):
@@ -262,7 +257,7 @@ class BaseTTS(BaseModel):
 d_vector_mapping = None
 # setup multi-lingual attributes
-if hasattr(self, "language_manager"):
+if hasattr(self, "language_manager") and self.language_manager is not None:
 language_id_mapping = (
 self.language_manager.language_id_mapping if self.args.use_language_embedding else None
 )
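
The tightened guard above matters because the new constructor always assigns `language_manager`, possibly as `None`, so `hasattr` alone no longer tells you whether a manager was actually configured. A minimal, self-contained sketch of the difference, using a hypothetical toy class rather than the real `BaseTTS`:

```python
class Toy:
    def __init__(self, language_manager=None):
        # Mirrors the new constructor behaviour: the attribute always exists,
        # even when no language manager is passed in.
        self.language_manager = language_manager


m = Toy()
print(hasattr(m, "language_manager"))  # True, even though no manager was given
print(hasattr(m, "language_manager") and m.language_manager is not None)  # False -> multi-lingual setup is skipped
```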


@@ -119,6 +119,10 @@ class TTSTokenizer:
 return [self.characters.bos] + list(char_sequence) + [self.characters.eos]
 def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False):
+"""Intersperses the blank character between characters in a sequence.
+Use the ```blank``` character if defined else use the ```pad``` character.
+"""
 char_to_use = self.characters.blank if use_blank_char else self.characters.pad
 result = [char_to_use] * (len(char_sequence) * 2 + 1)
 result[1::2] = char_sequence
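
The interspersing logic shown in the hunk above can be illustrated with a small standalone snippet, using plain strings and an underscore as the pad/blank character rather than the real `TTSTokenizer` and its character set:

```python
chars = ["a", "b", "c"]
blank = "_"

# Build a list twice as long plus one, pre-filled with the blank character,
# then write the original characters into the odd positions.
result = [blank] * (len(chars) * 2 + 1)
result[1::2] = chars
print(result)  # ['_', 'a', '_', 'b', '_', 'c', '_']
```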