diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index becf22f9..6dd7ca72 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -9,7 +9,7 @@ from torch import nn from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from TTS.model import BaseModel +from TTS.model import BaseTrainerModel from TTS.tts.datasets.dataset import TTSDataset from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler @@ -19,27 +19,22 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram # pylint: skip-file -class BaseTTS(BaseModel): +class BaseTTS(BaseTrainerModel): """Base `tts` class. Every new `tts` model must inherit this. It defines common `tts` specific functions on top of `Model` implementation. - - Notes on input/output tensor shapes: - Any input or output tensor of the model must be shaped as - - - 3D tensors `batch x time x channels` - - 2D tensors `batch x channels` - - 1D tensors `batch x 1` """ def __init__( - self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None + self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None, + language_manager: LanguageManager = None ): - super().__init__(config) + super().__init__() self.config = config self.ap = ap self.tokenizer = tokenizer self.speaker_manager = speaker_manager + self.language_manager = language_manager self._set_model_args(config) def _set_model_args(self, config: Coqpit): @@ -262,7 +257,7 @@ class BaseTTS(BaseModel): d_vector_mapping = None # setup multi-lingual attributes - if hasattr(self, "language_manager"): + if hasattr(self, "language_manager") and self.language_manager is not None: language_id_mapping = ( self.language_manager.language_id_mapping if self.args.use_language_embedding else None ) diff --git a/TTS/tts/utils/text/tokenizer.py b/TTS/tts/utils/text/tokenizer.py index bdaf8ea6..50a5f519 100644 --- a/TTS/tts/utils/text/tokenizer.py +++ b/TTS/tts/utils/text/tokenizer.py @@ -119,6 +119,10 @@ class TTSTokenizer: return [self.characters.bos] + list(char_sequence) + [self.characters.eos] def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False): + """Intersperses the blank character between characters in a sequence. + + Use the ```blank``` character if defined else use the ```pad``` character. + """ char_to_use = self.characters.blank if use_blank_char else self.characters.pad result = [char_to_use] * (len(char_sequence) * 2 + 1) result[1::2] = char_sequence