mirror of https://github.com/coqui-ai/TTS.git
Implement BaseTTS
This commit is contained in:
parent
2bad098625
commit
35fc7270ff
|
@ -9,7 +9,7 @@ from torch import nn
|
|||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from TTS.model import BaseModel
|
||||
from TTS.model import BaseTrainerModel
|
||||
from TTS.tts.datasets.dataset import TTSDataset
|
||||
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
|
||||
|
@ -19,27 +19,22 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
|||
# pylint: skip-file
|
||||
|
||||
|
||||
class BaseTTS(BaseModel):
|
||||
class BaseTTS(BaseTrainerModel):
|
||||
"""Base `tts` class. Every new `tts` model must inherit this.
|
||||
|
||||
It defines common `tts` specific functions on top of `Model` implementation.
|
||||
|
||||
Notes on input/output tensor shapes:
|
||||
Any input or output tensor of the model must be shaped as
|
||||
|
||||
- 3D tensors `batch x time x channels`
|
||||
- 2D tensors `batch x channels`
|
||||
- 1D tensors `batch x 1`
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None
|
||||
self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None,
|
||||
language_manager: LanguageManager = None
|
||||
):
|
||||
super().__init__(config)
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
self.tokenizer = tokenizer
|
||||
self.speaker_manager = speaker_manager
|
||||
self.language_manager = language_manager
|
||||
self._set_model_args(config)
|
||||
|
||||
def _set_model_args(self, config: Coqpit):
|
||||
|
@ -262,7 +257,7 @@ class BaseTTS(BaseModel):
|
|||
d_vector_mapping = None
|
||||
|
||||
# setup multi-lingual attributes
|
||||
if hasattr(self, "language_manager"):
|
||||
if hasattr(self, "language_manager") and self.language_manager is not None:
|
||||
language_id_mapping = (
|
||||
self.language_manager.language_id_mapping if self.args.use_language_embedding else None
|
||||
)
|
||||
|
|
|
@ -119,6 +119,10 @@ class TTSTokenizer:
|
|||
return [self.characters.bos] + list(char_sequence) + [self.characters.eos]
|
||||
|
||||
def intersperse_blank_char(self, char_sequence: List[str], use_blank_char: bool = False):
|
||||
"""Intersperses the blank character between characters in a sequence.
|
||||
|
||||
Use the ```blank``` character if defined else use the ```pad``` character.
|
||||
"""
|
||||
char_to_use = self.characters.blank if use_blank_char else self.characters.pad
|
||||
result = [char_to_use] * (len(char_sequence) * 2 + 1)
|
||||
result[1::2] = char_sequence
|
||||
|
|
Loading…
Reference in New Issue