diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index d28f188e..f2e7779c 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -13,13 +13,13 @@ from trainer.trainer_utils import get_optimizer
 
 from TTS.encoder.dataset import EncoderDataset
 from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
-from TTS.encoder.utils.samplers import PerfectBatchSampler
 from TTS.encoder.utils.training import init_training
 from TTS.encoder.utils.visual import plot_embeddings
 from TTS.tts.datasets import load_tts_samples
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
 from TTS.utils.io import copy_model_files
+from TTS.utils.samplers import PerfectBatchSampler
 from TTS.utils.training import check_update
 
 torch.backends.cudnn.enabled = True
diff --git a/TTS/tts/configs/vits_config.py b/TTS/tts/configs/vits_config.py
index df9116f3..3469f701 100644
--- a/TTS/tts/configs/vits_config.py
+++ b/TTS/tts/configs/vits_config.py
@@ -70,6 +70,18 @@ class VitsConfig(BaseTTSConfig):
         compute_linear_spec (bool):
             If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
 
+        use_weighted_sampler (bool):
+            If true, use a weighted sampler with bucketing to balance samples between the datasets used in training. Defaults to `False`.
+
+        weighted_sampler_attrs (dict):
+            Key returned by the formatter to be used for weighted sampling. For example, `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
+            by overweighting `root_path` by 2.0. Defaults to `{}`.
+
+        weighted_sampler_multipliers (dict):
+            Weight each unique value of a key returned by the formatter for weighted sampling.
+            For example, `{"root_path": {"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/": 1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}}`
+            samples instances from `train-clean-100` twice as often as from `train-clean-360`. Defaults to `{}`.
+
         r (int):
             Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
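For orientation, here is a minimal sketch of how these three options could be combined in a training config. The dataset paths are hypothetical; `VitsConfig` is a Coqpit dataclass, so keyword initialization works as shown:

    # Hypothetical example: balance two datasets by the "root_path" key
    # that the formatters attach to every sample (see formatters.py below).
    from TTS.tts.configs.vits_config import VitsConfig

    config = VitsConfig(
        use_weighted_sampler=True,
        weighted_sampler_attrs={"root_path": 1.0},
        weighted_sampler_multipliers={
            "root_path": {
                "/data/dataset_a/": 1.0,  # sampled twice as often as dataset_b
                "/data/dataset_b/": 0.5,
            }
        },
    )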
@@ -124,6 +136,11 @@ class VitsConfig(BaseTTSConfig):
     return_wav: bool = True
     compute_linear_spec: bool = True
 
+    # sampler params
+    use_weighted_sampler: bool = False  # TODO: move it to the base config
+    weighted_sampler_attrs: dict = field(default_factory=lambda: {})
+    weighted_sampler_multipliers: dict = field(default_factory=lambda: {})
+
     # overrides
     r: int = 1  # DO NOT CHANGE
     add_blank: bool = True
diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py
index ef05ea7c..a4be2b33 100644
--- a/TTS/tts/datasets/formatters.py
+++ b/TTS/tts/datasets/formatters.py
@@ -34,6 +34,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
                 "audio_file": audio_path,
                 "speaker_name": speaker_name if speaker_name is not None else row.speaker_name,
                 "emotion_name": emotion_name if emotion_name is not None else row.emotion_name,
+                "root_path": root_path,
             }
         )
     if not_found_counter > 0:
@@ -53,7 +54,7 @@ def tweb(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("\t")
             wav_file = os.path.join(root_path, cols[0] + ".wav")
             text = cols[1]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
@@ -68,7 +69,7 @@ def mozilla(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             wav_file = cols[1].strip()
             text = cols[0].strip()
             wav_file = os.path.join(root_path, "wavs", wav_file)
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
@@ -84,7 +85,7 @@ def mozilla_de(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             text = cols[1].strip()
             folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
             wav_file = os.path.join(root_path, folder_name, wav_file)
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
@@ -130,7 +131,9 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
                     wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
                     if os.path.isfile(wav_file):
                         text = cols[1].strip()
-                        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+                        items.append(
+                            {"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}
+                        )
                     else:
                         # M-AI-Labs have some missing samples, so just print the warning
                         print("> File %s does not exist!" % (wav_file))
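The pattern above repeats across every formatter in this file: each returned item is tagged with the dataset root it came from, which is the key the weighted sampler balances on. A toy sketch of the resulting item schema, with invented values:

    # Shape of a single formatter item after this change (values hypothetical).
    item = {
        "text": "Hello world.",
        "audio_file": "/data/dataset_a/wavs/0001.wav",
        "speaker_name": "speaker_0",
        "root_path": "/data/dataset_a/",  # new key, used for dataset-level balancing
    }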
@@ -148,7 +151,7 @@ def ljspeech(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
@@ -166,7 +169,9 @@ def ljspeech_test(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}"})
+            items.append(
+                {"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}", "root_path": root_path}
+            )
     return items
@@ -181,7 +186,7 @@ def thorsten(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[1]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
@@ -198,7 +203,7 @@ def sam_accenture(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
         if not os.path.exists(wav_file):
             print(f" [!] {wav_file} in metafile does not exist. Skipping...")
             continue
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
@@ -213,7 +218,7 @@ def ruslan(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
             text = cols[1]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
@@ -261,7 +266,9 @@ def common_voice(root_path, meta_file, ignored_speakers=None):
                 if speaker_name in ignored_speakers:
                     continue
             wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name})
+            items.append(
+                {"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name, "root_path": root_path}
+            )
     return items
@@ -288,7 +295,14 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None):
             if isinstance(ignored_speakers, list):
                 if speaker_name in ignored_speakers:
                     continue
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"})
+            items.append(
+                {
+                    "text": text,
+                    "audio_file": wav_file,
+                    "speaker_name": f"LTTS_{speaker_name}",
+                    "root_path": root_path,
+                }
+            )
     for item in items:
         assert os.path.exists(item["audio_file"]), f" [!] wav files don't exist - {item['audio_file']}"
     return items
@@ -307,7 +321,7 @@ def custom_turkish(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             skipped_files.append(wav_file)
             continue
         text = cols[1].strip()
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
They don't exist...") return items @@ -329,7 +343,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None): if isinstance(ignored_speakers, list): if speaker_id in ignored_speakers: continue - items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id}) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id, "root_path": root_path}) return items @@ -372,7 +386,9 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic else: wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}") if os.path.exists(wav_file): - items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id}) + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path} + ) else: print(f" [!] wav files don't exist - {wav_file}") return items @@ -392,7 +408,9 @@ def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=Non with open(meta_file, "r", encoding="utf-8") as file_text: text = file_text.readlines()[0] wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") - items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id}) + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id, "root_path": root_path} + ) return items @@ -411,7 +429,7 @@ def synpaflex(root_path, metafiles=None, **kwargs): # pylint: disable=unused-ar if os.path.exists(txt_file) and os.path.exists(wav_file): with open(txt_file, "r", encoding="utf-8") as file_text: text = file_text.readlines()[0] - items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name}) + items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}) return items @@ -433,7 +451,7 @@ def open_bible(root_path, meta_files="train", ignore_digits_sentences=True, igno if ignore_digits_sentences and any(map(str.isdigit, text)): continue wav_file = os.path.join(root_path, split_dir, speaker_id, file_id + ".flac") - items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id}) + items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id, "root_path": root_path}) return items @@ -450,7 +468,9 @@ def mls(root_path, meta_files=None, ignored_speakers=None): if isinstance(ignored_speakers, list): if speaker in ignored_speakers: continue - items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker}) + items.append( + {"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker, "root_path": root_path} + ) return items @@ -520,7 +540,9 @@ def emotion(root_path, meta_file, ignored_speakers=None): if isinstance(ignored_speakers, list): if speaker_id in ignored_speakers: continue - items.append({"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id}) + items.append( + {"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id, "root_path": root_path} + ) return items @@ -540,7 +562,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylin for line in ttf: wav_name, text = line.rstrip("\n").split("|") wav_path = os.path.join(root_path, "clips_22", wav_name) - items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name}) + items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name, "root_path": root_path}) return items @@ -554,5 
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2].replace(" ", "")
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 1198bada..5dae47cd 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -4,6 +4,7 @@ from dataclasses import dataclass, field, replace
 from itertools import chain
 from typing import Dict, List, Tuple, Union
 
+import numpy as np
 import torch
 import torch.distributed as dist
 import torchaudio
@@ -13,6 +14,8 @@ from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
+from torch.utils.data.sampler import WeightedRandomSampler
+from trainer.torch import DistributedSampler, DistributedSamplerWrapper
 from trainer.trainer_utils import get_optimizer, get_scheduler
 
 from TTS.tts.configs.shared_configs import CharactersConfig
@@ -29,6 +32,8 @@ from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
+from TTS.utils.io import load_fsspec
+from TTS.utils.samplers import BucketBatchSampler
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
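As background for the sampler code below, here is a self-contained demo of the stock torch `WeightedRandomSampler` that the new `get_sampler` builds on; the weights are invented for illustration:

    # Standalone demo: higher-weight indices are drawn proportionally more often.
    import torch
    from torch.utils.data.sampler import WeightedRandomSampler

    weights = torch.tensor([1.0, 1.0, 8.0])  # third sample is 8x more likely per draw
    sampler = WeightedRandomSampler(weights, num_samples=1000, replacement=True)
    counts = torch.bincount(torch.tensor(list(sampler)), minlength=3)
    print(counts)  # roughly proportional to [1, 1, 8]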
@@ -221,6 +226,30 @@ class VitsAudioConfig(Coqpit):
 ##############################
 
 
+def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
+    """Create inverse-frequency weights for balancing the dataset.
+    Use `multi_dict` to scale the relative weight of each unique attribute value."""
+    attr_names_samples = np.array([item[attr_name] for item in items])
+    unique_attr_names = np.unique(attr_names_samples).tolist()
+    attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
+    attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
+    weight_attr = 1.0 / attr_count
+    dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
+    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
+    if multi_dict is not None:
+        # check that all keys in the multi_dict are valid attribute values
+        for k in multi_dict:
+            assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
+        # scale the weights
+        multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
+        dataset_samples_weight *= multiplier_samples
+    return (
+        torch.from_numpy(dataset_samples_weight).float(),
+        unique_attr_names,
+        np.unique(dataset_samples_weight).tolist(),
+    )
+
+
 class VitsDataset(TTSDataset):
     def __init__(self, model_args, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -1510,6 +1539,42 @@ class Vits(BaseTTS):
         batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1)
         return batch
 
+    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=False):
+        weights = None
+        data_items = dataset.samples
+        if getattr(config, "use_weighted_sampler", False):
+            for attr_name, alpha in config.weighted_sampler_attrs.items():
+                print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'")
+                multi_dict = config.weighted_sampler_multipliers.get(attr_name, None)
+                print(f" > Sampler multipliers: {multi_dict}")
+                weights, attr_names, attr_weights = get_attribute_balancer_weights(
+                    attr_name=attr_name, items=data_items, multi_dict=multi_dict
+                )
+                weights = weights * alpha
+                print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}")
+
+        # input_audio_lengths = [os.path.getsize(x["audio_file"]) for x in data_items]
+
+        if weights is not None:
+            w_sampler = WeightedRandomSampler(weights, len(weights))
+            batch_sampler = BucketBatchSampler(
+                w_sampler,
+                data=data_items,
+                batch_size=config.eval_batch_size if is_eval else config.batch_size,
+                sort_key=lambda x: os.path.getsize(x["audio_file"]),
+                drop_last=True,
+            )
+        else:
+            batch_sampler = None
+        # sampler for DDP
+        if batch_sampler is None:
+            batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None
+        else:  # if a sampler is already defined, wrap it so it composes with the DDP sampler
+            batch_sampler = (
+                DistributedSamplerWrapper(batch_sampler) if num_gpus > 1 else batch_sampler
+            )  # TODO: check batch_sampler with multi-gpu
+        return batch_sampler
+
     def get_data_loader(
         self,
         config: Coqpit,
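To make the balancing concrete, here is a small numeric sketch of the inverse-frequency weighting that `get_attribute_balancer_weights` performs, run on a toy item list with invented paths (the normalization is simplified to a plain sum):

    # Toy demo of inverse-frequency balancing.
    import numpy as np

    items = [{"root_path": "/data/dataset_a/"}] * 3 + [{"root_path": "/data/dataset_b/"}]
    names = np.array([it["root_path"] for it in items])
    unique = np.unique(names).tolist()
    counts = np.array([(names == u).sum() for u in unique])  # [3, 1]
    weight_per_value = 1.0 / counts  # [1/3, 1.0]
    per_item = np.array([weight_per_value[unique.index(n)] for n in names])
    # Each dataset now carries equal total probability mass: the single
    # dataset_b sample is drawn as often as all three dataset_a samples combined.
    print(per_item / per_item.sum())  # [1/6, 1/6, 1/6, 1/2]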
@@ -1551,10 +1616,7 @@
 
             loader = DataLoader(
                 dataset,
-                batch_size=config.eval_batch_size if is_eval else config.batch_size,
-                shuffle=False,  # shuffle is done in the dataset.
-                drop_last=False,  # setting this False might cause issues in AMP training.
-                sampler=sampler,
+                batch_sampler=sampler,
                 collate_fn=dataset.collate_fn,
                 num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                 pin_memory=False,
@@ -1615,7 +1677,7 @@
         strict=True,
     ):  # pylint: disable=unused-argument, redefined-builtin
         """Load the model checkpoint and setup for training or inference"""
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         # compat band-aid for the pre-trained models to not use the encoder baked into the model
         # TODO: consider baking the speaker encoder into the model and call it from there.
         # as it is probably easier for model distribution.
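The `load_fsspec` swap above means checkpoints can be read from any fsspec-compatible location, not only the local filesystem. A hedged sketch; the URL is purely illustrative:

    # Illustrative only: load_fsspec accepts local paths as well as remote
    # fsspec URLs, unlike a bare torch.load.
    import torch
    from TTS.utils.io import load_fsspec

    state = load_fsspec("https://example.com/checkpoints/vits.pth", map_location=torch.device("cpu"))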
diff --git a/TTS/utils/samplers.py b/TTS/utils/samplers.py
new file mode 100644
index 00000000..df5d4185
--- /dev/null
+++ b/TTS/utils/samplers.py
@@ -0,0 +1,202 @@
+import math
+import random
+from typing import Callable, List, Union
+
+from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler
+
+
+class SubsetSampler(Sampler):
+    """Samples elements sequentially from a given list of indices.
+
+    Args:
+        indices (list): a sequence of indices
+    """
+
+    def __init__(self, indices):
+        super().__init__(indices)
+        self.indices = indices
+
+    def __iter__(self):
+        return (self.indices[i] for i in range(len(self.indices)))
+
+    def __len__(self):
+        return len(self.indices)
+
+
+class PerfectBatchSampler(Sampler):
+    """Samples mini-batches of indices for balanced class batching.
+
+    Args:
+        dataset_items (list): dataset items to sample from.
+        classes (list): list of classes of dataset_items to sample from.
+        batch_size (int): total number of samples to be sampled in a mini-batch.
+        num_classes_in_batch (int): number of distinct classes drawn in each batch.
+        num_gpus (int): number of GPUs in data-parallel mode.
+        shuffle (bool): if True, samples randomly, otherwise samples sequentially.
+        drop_last (bool): if True, drops the last incomplete batch.
+    """
+
+    def __init__(
+        self,
+        dataset_items,
+        classes,
+        batch_size,
+        num_classes_in_batch,
+        num_gpus=1,
+        shuffle=True,
+        drop_last=False,
+        label_key="class_name",
+    ):
+        super().__init__(dataset_items)
+        assert (
+            batch_size % (num_classes_in_batch * num_gpus) == 0
+        ), "Batch size must be divisible by the number of classes times the number of data-parallel devices (if enabled)."
+
+        label_indices = {}
+        for idx, item in enumerate(dataset_items):
+            label = item[label_key]
+            if label not in label_indices.keys():
+                label_indices[label] = [idx]
+            else:
+                label_indices[label].append(idx)
+
+        if shuffle:
+            self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes]
+        else:
+            self._samplers = [SubsetSampler(label_indices[key]) for key in classes]
+
+        self._batch_size = batch_size
+        self._drop_last = drop_last
+        self._dp_devices = num_gpus
+        self._num_classes_in_batch = num_classes_in_batch
+
+    def __iter__(self):
+
+        batch = []
+        if self._num_classes_in_batch != len(self._samplers):
+            valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
+        else:
+            valid_samplers_idx = None
+
+        iters = [iter(s) for s in self._samplers]
+        done = False
+
+        while True:
+            b = []
+            for i, it in enumerate(iters):
+                if valid_samplers_idx is not None and i not in valid_samplers_idx:
+                    continue
+                idx = next(it, None)
+                if idx is None:
+                    done = True
+                    break
+                b.append(idx)
+            if done:
+                break
+            batch += b
+            if len(batch) == self._batch_size:
+                yield batch
+                batch = []
+                if valid_samplers_idx is not None:
+                    valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
+
+        if not self._drop_last:
+            if len(batch) > 0:
+                groups = len(batch) // self._num_classes_in_batch
+                if groups % self._dp_devices == 0:
+                    yield batch
+                else:
+                    batch = batch[: (groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch]
+                    if len(batch) > 0:
+                        yield batch
+
+    def __len__(self):
+        class_batch_size = self._batch_size // self._num_classes_in_batch
+        return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers)
+
+
+def identity(x):
+    return x
+
+
+class SortedSampler(Sampler):
+    """Samples elements sequentially, always in the same order.
+
+    Taken from https://github.com/PetrochukM/PyTorch-NLP
+
+    Args:
+        data (iterable): Iterable data.
+        sort_key (callable): Specifies a function of one argument that is used to extract a
+            numerical comparison key from each list element.
+
+    Example:
+        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
+        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
+
+    """
+
+    def __init__(self, data, sort_key: Callable = identity):
+        super().__init__(data)
+        self.data = data
+        self.sort_key = sort_key
+        zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)]
+        zip_ = sorted(zip_, key=lambda r: r[1])
+        self.sorted_indexes = [item[0] for item in zip_]
+
+    def __iter__(self):
+        return iter(self.sorted_indexes)
+
+    def __len__(self):
+        return len(self.data)
+class BucketBatchSampler(BatchSampler):
+    """Bucket batch sampler.
+
+    Adapted from https://github.com/PetrochukM/PyTorch-NLP
+
+    Args:
+        sampler (torch.utils.data.sampler.Sampler): Base sampler providing indices.
+        batch_size (int): Size of mini-batch.
+        drop_last (bool): If `True`, the sampler will drop the last batch if its size would be less
+            than `batch_size`.
+        data (list): List of data samples.
+        sort_key (callable, optional): Callable to specify a comparison key for sorting.
+        bucket_size_multiplier (int, optional): Buckets are of size
+            `batch_size * bucket_size_multiplier`.
+
+    Example:
+        >>> sampler = WeightedRandomSampler(weights, len(weights))
+        >>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True)
+    """
+
+    def __init__(
+        self,
+        sampler,
+        data,
+        batch_size,
+        drop_last,
+        sort_key: Union[Callable, List] = identity,
+        bucket_size_multiplier=100,
+    ):
+        super().__init__(sampler, batch_size, drop_last)
+        self.data = data
+        self.sort_key = sort_key
+        _bucket_size = batch_size * bucket_size_multiplier
+        if hasattr(sampler, "__len__"):
+            _bucket_size = min(_bucket_size, len(sampler))
+        self.bucket_sampler = BatchSampler(sampler, _bucket_size, False)
+
+    def __iter__(self):
+        for idxs in self.bucket_sampler:
+            bucket_data = [self.data[idx] for idx in idxs]
+            sorted_sampler = SortedSampler(bucket_data, self.sort_key)
+            for batch_idx in SubsetRandomSampler(list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))):
+                sorted_idxs = [idxs[i] for i in batch_idx]
+                yield sorted_idxs
+
+    def __len__(self):
+        if self.drop_last:
+            return len(self.sampler) // self.batch_size
+        return math.ceil(len(self.sampler) / self.batch_size)
diff --git a/tests/data_tests/test_samplers.py b/tests/data_tests/test_samplers.py
index b85e0ec4..730d0d8b 100644
--- a/tests/data_tests/test_samplers.py
+++ b/tests/data_tests/test_samplers.py
@@ -5,11 +5,11 @@ import unittest
 import torch
 
 from TTS.config.shared_configs import BaseDatasetConfig
-from TTS.encoder.utils.samplers import PerfectBatchSampler
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.utils.data import get_length_balancer_weights
 from TTS.tts.utils.languages import get_language_balancer_weights
 from TTS.tts.utils.speakers import get_speaker_balancer_weights
+from TTS.utils.samplers import BucketBatchSampler, PerfectBatchSampler
 
 # Fixing random state to avoid random fails
 torch.manual_seed(0)
@@ -163,3 +163,31 @@ class TestSamplers(unittest.TestCase):
             else:
                 len2 += 1
         assert is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced"
+
+    def test_bucket_batch_sampler(self):
+        bucket_size_multiplier = 2
+        sampler = range(len(train_samples))
+        sampler = BucketBatchSampler(
+            sampler,
+            data=train_samples,
+            batch_size=7,
+            drop_last=True,
+            sort_key=lambda x: len(x["text"]),
+            bucket_size_multiplier=bucket_size_multiplier,
+        )
+
+        # check that the samples are sorted by text length while bucketing
+        min_text_len_in_bucket = 0
+        bucket_items = []
+        for batch_idx, batch in enumerate(list(sampler)):
+            if (batch_idx + 1) % bucket_size_multiplier == 0:
+                for bucket_item in bucket_items:
+                    self.assertLessEqual(min_text_len_in_bucket, len(train_samples[bucket_item]["text"]))
+                    min_text_len_in_bucket = len(train_samples[bucket_item]["text"])
+                min_text_len_in_bucket = 0
+                bucket_items = []
+            else:
+                bucket_items += batch
+
+        # check sampler length
+        self.assertEqual(len(sampler), len(train_samples) // 7)
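For completeness, two hedged sketches of the new samplers in use. First, a toy run of `PerfectBatchSampler`, where every batch draws evenly from the requested classes; the items are invented:

    # Every batch of 4 contains 2 indices per class when shuffle=False.
    from TTS.utils.samplers import PerfectBatchSampler

    items = [{"class_name": "a"}] * 4 + [{"class_name": "b"}] * 4
    sampler = PerfectBatchSampler(items, classes=["a", "b"], batch_size=4, num_classes_in_batch=2, shuffle=False)
    for batch in sampler:
        print(batch, [items[i]["class_name"] for i in batch])
    # [0, 4, 1, 5] ['a', 'b', 'a', 'b']
    # [2, 6, 3, 7] ['a', 'b', 'a', 'b']

Second, how the pieces compose into a DataLoader, mirroring what `Vits.get_sampler` and `get_data_loader` now do. Here `dataset` and `data_items` are stand-ins, and `get_attribute_balancer_weights` is the helper added to vits.py above:

    # Weighted draws + size-sorted buckets, fed to the loader as a batch sampler.
    import os
    from torch.utils.data import DataLoader
    from torch.utils.data.sampler import WeightedRandomSampler
    from TTS.utils.samplers import BucketBatchSampler

    weights, _, _ = get_attribute_balancer_weights(items=data_items, attr_name="root_path")
    w_sampler = WeightedRandomSampler(weights, len(weights))
    batch_sampler = BucketBatchSampler(
        w_sampler,
        data=data_items,
        batch_size=32,
        drop_last=True,
        sort_key=lambda x: os.path.getsize(x["audio_file"]),  # file size as a length proxy
    )
    # A batch sampler replaces batch_size/shuffle/drop_last on the DataLoader.
    loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=dataset.collate_fn)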