diff --git a/TTS/tts/configs/shared_configs.py b/TTS/tts/configs/shared_configs.py index b101b70a..98461bdd 100644 --- a/TTS/tts/configs/shared_configs.py +++ b/TTS/tts/configs/shared_configs.py @@ -146,11 +146,19 @@ class BaseTTSConfig(BaseTrainingConfig): sort_by_audio_len (bool): If true, dataloder sorts the data by audio length else sorts by the input text length. Defaults to `False`. - min_seq_len (int): - Minimum sequence length to be used at training. + min_text_len (int): + Minimum length of input text to be used. All shorter samples will be ignored. Defaults to 1. - max_seq_len (int): - Maximum sequence length to be used at training. Larger values result in more VRAM usage. + max_text_len (int): + Maximum length of input text to be used. All longer samples will be ignored. Defaults to float("inf"). + + min_audio_len (int): + Minimum length of input audio to be used. All shorter samples will be ignored. Defaults to 1. + + max_audio_len (int): + Maximum length of input audio to be used. All longer samples will be ignored. The maximum length in the + dataset defines the VRAM used in the training. Hence, pay attention to this value if you encounter an + OOM error in training. Defaults to float("inf"). compute_f0 (int): (Not in use yet). 
@@ -211,8 +219,10 @@ class BaseTTSConfig(BaseTrainingConfig): loss_masking: bool = None # dataloading sort_by_audio_len: bool = False - min_seq_len: int = 1 - max_seq_len: int = float("inf") + min_audio_len: int = 1 + max_audio_len: int = float("inf") + min_text_len: int = 1 + max_text_len: int = float("inf") compute_f0: bool = False compute_linear_spec: bool = False use_noise_augment: bool = False diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index b9b4ed57..27231790 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -168,8 +168,8 @@ class BaseTTS(BaseModel): Dict: [description] """ # setup input batch - text_input = batch["text"] - text_lengths = batch["text_lengths"] + text_input = batch["token_id"] + text_lengths = batch["token_id_lengths"] speaker_names = batch["speaker_names"] linear_input = batch["linear"] mel_input = batch["mel"] @@ -261,10 +261,6 @@ class BaseTTS(BaseModel): d_vector_mapping = None # setup custom symbols if needed - custom_symbols = None - if hasattr(self, "make_symbols"): - custom_symbols = self.make_symbols(self.config) - if hasattr(self, "language_manager"): language_id_mapping = ( self.language_manager.language_id_mapping if self.args.use_language_embedding else None @@ -282,8 +278,10 @@ class BaseTTS(BaseModel): ap=self.ap, return_wav=config.return_wav if "return_wav" in config else False, batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size, - min_seq_len=config.min_seq_len, - max_seq_len=config.max_seq_len, + min_text_len=config.min_text_len, + max_text_len=config.max_text_len, + min_audio_len=config.min_audio_len, + max_audio_len=config.max_audio_len, phoneme_cache_path=config.phoneme_cache_path, use_noise_augment=False if is_eval else config.use_noise_augment, verbose=verbose, @@ -292,45 +290,12 @@ class BaseTTS(BaseModel): tokenizer=self.tokenizer, ) - # pre-compute phonemes - if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]: 
- if hasattr(self, "eval_data_items") and is_eval: - dataset.items = self.eval_data_items - elif hasattr(self, "train_data_items") and not is_eval: - dataset.items = self.train_data_items - else: - # precompute phonemes for precise estimate of sequence lengths. - # otherwise `dataset.sort_items()` uses raw text lengths - dataset.compute_input_seq(config.num_loader_workers) - - # TODO: find a more efficient solution - # cheap hack - store items in the model state to avoid recomputing when reinit the dataset - if is_eval: - self.eval_data_items = dataset.items - else: - self.train_data_items = dataset.items - - # halt DDP processes for the main process to finish computing the phoneme cache + # wait all the DDP process to be ready if num_gpus > 1: dist.barrier() # sort input sequences from short to long - dataset.sort_and_filter_items(config.get("sort_by_audio_len", default=False)) - - # compute pitch frames and write to files. - if config.compute_f0 and rank in [None, 0]: - if not os.path.exists(config.f0_cache_path): - dataset.pitch_extractor.compute_pitch( - self.ap, config.get("f0_cache_path", None), config.num_loader_workers - ) - - # halt DDP processes for the main process to finish computing the F0 cache - if num_gpus > 1: - dist.barrier() - - # load pitch stats computed above by all the workers - if config.compute_f0: - dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None)) + dataset.preprocess_samples() # sampler for DDP sampler = DistributedSampler(dataset) if num_gpus > 1 else None