Implement bucketed weighted sampling for VITS (#1871)

This commit is contained in:
Eren Gölge 2022-08-15 11:08:11 +02:00 committed by GitHub
parent d46fbc240c
commit bfc63829ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 359 additions and 28 deletions

View File

@ -13,13 +13,13 @@ from trainer.trainer_utils import get_optimizer
from TTS.encoder.dataset import EncoderDataset
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
from TTS.encoder.utils.samplers import PerfectBatchSampler
from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings
from TTS.tts.datasets import load_tts_samples
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
from TTS.utils.io import copy_model_files
from TTS.utils.samplers import PerfectBatchSampler
from TTS.utils.training import check_update
torch.backends.cudnn.enabled = True

View File

@ -70,6 +70,18 @@ class VitsConfig(BaseTTSConfig):
compute_linear_spec (bool):
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
use_weighted_sampler (bool):
If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.
weighted_sampler_attrs (dict):
Key retuned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
by overweighting `root_path` by 2.0. Defaults to `{}`.
weighted_sampler_multipliers (dict):
Weight each unique value of a key returned by the formatter for weighted sampling.
For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`.
It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.
r (int):
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
@ -124,6 +136,11 @@ class VitsConfig(BaseTTSConfig):
return_wav: bool = True
compute_linear_spec: bool = True
# sampler params
use_weighted_sampler: bool = False # TODO: move it to the base config
weighted_sampler_attrs: dict = field(default_factory=lambda: {})
weighted_sampler_multipliers: dict = field(default_factory=lambda: {})
# overrides
r: int = 1 # DO NOT CHANGE
add_blank: bool = True

View File

@ -34,6 +34,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
"audio_file": audio_path,
"speaker_name": speaker_name if speaker_name is not None else row.speaker_name,
"emotion_name": emotion_name if emotion_name is not None else row.emotion_name,
"root_path": root_path,
}
)
if not_found_counter > 0:
@ -53,7 +54,7 @@ def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("\t")
wav_file = os.path.join(root_path, cols[0] + ".wav")
text = cols[1]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items
@ -68,7 +69,7 @@ def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
wav_file = cols[1].strip()
text = cols[0].strip()
wav_file = os.path.join(root_path, "wavs", wav_file)
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items
@ -84,7 +85,7 @@ def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argume
text = cols[1].strip()
folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
wav_file = os.path.join(root_path, folder_name, wav_file)
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items
@ -130,7 +131,9 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
if os.path.isfile(wav_file):
text = cols[1].strip()
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append(
{"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}
)
else:
# M-AI-Labs have some missing samples, so just print the warning
print("> File %s does not exist!" % (wav_file))
@ -148,7 +151,7 @@ def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[2]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items
@ -166,7 +169,9 @@ def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[2]
items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}"})
items.append(
{"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}", "root_path": root_path}
)
return items
@ -181,7 +186,7 @@ def thorsten(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[1]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items
@ -198,7 +203,7 @@ def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
if not os.path.exists(wav_file):
print(f" [!] {wav_file} in metafile does not exist. Skipping...")
continue
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items
@ -213,7 +218,7 @@ def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
text = cols[1]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items
@ -261,7 +266,9 @@ def common_voice(root_path, meta_file, ignored_speakers=None):
if speaker_name in ignored_speakers:
continue
wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name})
items.append(
{"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name, "root_path": root_path}
)
return items
@ -288,7 +295,14 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker_name in ignored_speakers:
continue
items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"})
items.append(
{
"text": text,
"audio_file": wav_file,
"speaker_name": f"LTTS_{speaker_name}",
"root_path": root_path,
}
)
for item in items:
assert os.path.exists(item["audio_file"]), f" [!] wav files don't exist - {item['audio_file']}"
return items
@ -307,7 +321,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar
skipped_files.append(wav_file)
continue
text = cols[1].strip()
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
return items
@ -329,7 +343,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id, "root_path": root_path})
return items
@ -372,7 +386,9 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
else:
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}")
if os.path.exists(wav_file):
items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id})
items.append(
{"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path}
)
else:
print(f" [!] wav files don't exist - {wav_file}")
return items
@ -392,7 +408,9 @@ def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=Non
with open(meta_file, "r", encoding="utf-8") as file_text:
text = file_text.readlines()[0]
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id})
items.append(
{"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id, "root_path": root_path}
)
return items
@ -411,7 +429,7 @@ def synpaflex(root_path, metafiles=None, **kwargs): # pylint: disable=unused-ar
if os.path.exists(txt_file) and os.path.exists(wav_file):
with open(txt_file, "r", encoding="utf-8") as file_text:
text = file_text.readlines()[0]
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items
@ -433,7 +451,7 @@ def open_bible(root_path, meta_files="train", ignore_digits_sentences=True, igno
if ignore_digits_sentences and any(map(str.isdigit, text)):
continue
wav_file = os.path.join(root_path, split_dir, speaker_id, file_id + ".flac")
items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id})
items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id, "root_path": root_path})
return items
@ -450,7 +468,9 @@ def mls(root_path, meta_files=None, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker in ignored_speakers:
continue
items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker})
items.append(
{"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker, "root_path": root_path}
)
return items
@ -520,7 +540,9 @@ def emotion(root_path, meta_file, ignored_speakers=None):
if isinstance(ignored_speakers, list):
if speaker_id in ignored_speakers:
continue
items.append({"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id})
items.append(
{"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id, "root_path": root_path}
)
return items
@ -540,7 +562,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylin
for line in ttf:
wav_name, text = line.rstrip("\n").split("|")
wav_path = os.path.join(root_path, "clips_22", wav_name)
items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name, "root_path": root_path})
return items
@ -554,5 +576,5 @@ def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[2].replace(" ", "")
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
return items

View File

@ -4,6 +4,7 @@ from dataclasses import dataclass, field, replace
from itertools import chain
from typing import Dict, List, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
import torchaudio
@ -13,6 +14,8 @@ from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.tts.configs.shared_configs import CharactersConfig
@ -29,6 +32,8 @@ from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment
from TTS.utils.io import load_fsspec
from TTS.utils.samplers import BucketBatchSampler
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
from TTS.vocoder.utils.generic_utils import plot_results
@ -221,6 +226,30 @@ class VitsAudioConfig(Coqpit):
##############################
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
"""Create inverse frequency weights for balancing the dataset.
Use `multi_dict` to scale relative weights."""
attr_names_samples = np.array([item[attr_name] for item in items])
unique_attr_names = np.unique(attr_names_samples).tolist()
attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
weight_attr = 1.0 / attr_count
dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
if multi_dict is not None:
# check if all keys are in the multi_dict
for k in multi_dict:
assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
# scale weights
multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
dataset_samples_weight *= multiplier_samples
return (
torch.from_numpy(dataset_samples_weight).float(),
unique_attr_names,
np.unique(dataset_samples_weight).tolist(),
)
class VitsDataset(TTSDataset):
def __init__(self, model_args, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -1510,6 +1539,42 @@ class Vits(BaseTTS):
batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1)
return batch
def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=False):
weights = None
data_items = dataset.samples
if getattr(config, "use_weighted_sampler", False):
for attr_name, alpha in config.weighted_sampler_attrs.items():
print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'")
multi_dict = config.weighted_sampler_multipliers.get(attr_name, None)
print(multi_dict)
weights, attr_names, attr_weights = get_attribute_balancer_weights(
attr_name=attr_name, items=data_items, multi_dict=multi_dict
)
weights = weights * alpha
print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}")
# input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items]
if weights is not None:
w_sampler = WeightedRandomSampler(weights, len(weights))
batch_sampler = BucketBatchSampler(
w_sampler,
data=data_items,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
sort_key=lambda x: os.path.getsize(x["audio_file"]),
drop_last=True,
)
else:
batch_sampler = None
# sampler for DDP
if batch_sampler is None:
batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None
else: # If a sampler is already defined use this sampler and DDP sampler together
batch_sampler = (
DistributedSamplerWrapper(batch_sampler) if num_gpus > 1 else batch_sampler
) # TODO: check batch_sampler with multi-gpu
return batch_sampler
def get_data_loader(
self,
config: Coqpit,
@ -1551,10 +1616,7 @@ class Vits(BaseTTS):
loader = DataLoader(
dataset,
batch_size=config.eval_batch_size if is_eval else config.batch_size,
shuffle=False, # shuffle is done in the dataset.
drop_last=False, # setting this False might cause issues in AMP training.
sampler=sampler,
batch_sampler=sampler,
collate_fn=dataset.collate_fn,
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
pin_memory=False,
@ -1615,7 +1677,7 @@ class Vits(BaseTTS):
strict=True,
): # pylint: disable=unused-argument, redefined-builtin
"""Load the model checkpoint and setup for training or inference"""
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
# compat band-aid for the pre-trained models to not use the encoder baked into the model
# TODO: consider baking the speaker encoder into the model and call it from there.
# as it is probably easier for model distribution.

202
TTS/utils/samplers.py Normal file
View File

@ -0,0 +1,202 @@
import math
import random
from typing import Callable, List, Union
from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler
class SubsetSampler(Sampler):
"""
Samples elements sequentially from a given list of indices.
Args:
indices (list): a sequence of indices
"""
def __init__(self, indices):
super().__init__(indices)
self.indices = indices
def __iter__(self):
return (self.indices[i] for i in range(len(self.indices)))
def __len__(self):
return len(self.indices)
class PerfectBatchSampler(Sampler):
"""
Samples a mini-batch of indices for a balanced class batching
Args:
dataset_items(list): dataset items to sample from.
classes (list): list of classes of dataset_items to sample from.
batch_size (int): total number of samples to be sampled in a mini-batch.
num_gpus (int): number of GPU in the data parallel mode.
shuffle (bool): if True, samples randomly, otherwise samples sequentially.
drop_last (bool): if True, drops last incomplete batch.
"""
def __init__(
self,
dataset_items,
classes,
batch_size,
num_classes_in_batch,
num_gpus=1,
shuffle=True,
drop_last=False,
label_key="class_name",
):
super().__init__(dataset_items)
assert (
batch_size % (num_classes_in_batch * num_gpus) == 0
), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)."
label_indices = {}
for idx, item in enumerate(dataset_items):
label = item[label_key]
if label not in label_indices.keys():
label_indices[label] = [idx]
else:
label_indices[label].append(idx)
if shuffle:
self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes]
else:
self._samplers = [SubsetSampler(label_indices[key]) for key in classes]
self._batch_size = batch_size
self._drop_last = drop_last
self._dp_devices = num_gpus
self._num_classes_in_batch = num_classes_in_batch
def __iter__(self):
batch = []
if self._num_classes_in_batch != len(self._samplers):
valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
else:
valid_samplers_idx = None
iters = [iter(s) for s in self._samplers]
done = False
while True:
b = []
for i, it in enumerate(iters):
if valid_samplers_idx is not None and i not in valid_samplers_idx:
continue
idx = next(it, None)
if idx is None:
done = True
break
b.append(idx)
if done:
break
batch += b
if len(batch) == self._batch_size:
yield batch
batch = []
if valid_samplers_idx is not None:
valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
if not self._drop_last:
if len(batch) > 0:
groups = len(batch) // self._num_classes_in_batch
if groups % self._dp_devices == 0:
yield batch
else:
batch = batch[: (groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch]
if len(batch) > 0:
yield batch
def __len__(self):
class_batch_size = self._batch_size // self._num_classes_in_batch
return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers)
def identity(x):
return x
class SortedSampler(Sampler):
"""Samples elements sequentially, always in the same order.
Taken from https://github.com/PetrochukM/PyTorch-NLP
Args:
data (iterable): Iterable data.
sort_key (callable): Specifies a function of one argument that is used to extract a
numerical comparison key from each list element.
Example:
>>> list(SortedSampler(range(10), sort_key=lambda i: -i))
[9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
"""
def __init__(self, data, sort_key: Callable = identity):
super().__init__(data)
self.data = data
self.sort_key = sort_key
zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)]
zip_ = sorted(zip_, key=lambda r: r[1])
self.sorted_indexes = [item[0] for item in zip_]
def __iter__(self):
return iter(self.sorted_indexes)
def __len__(self):
return len(self.data)
class BucketBatchSampler(BatchSampler):
"""Bucket batch sampler
Adapted from https://github.com/PetrochukM/PyTorch-NLP
Args:
sampler (torch.data.utils.sampler.Sampler):
batch_size (int): Size of mini-batch.
drop_last (bool): If `True` the sampler will drop the last batch if its size would be less
than `batch_size`.
data (list): List of data samples.
sort_key (callable, optional): Callable to specify a comparison key for sorting.
bucket_size_multiplier (int, optional): Buckets are of size
`batch_size * bucket_size_multiplier`.
Example:
>>> sampler = WeightedRandomSampler(weights, len(weights))
>>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True)
"""
def __init__(
self,
sampler,
data,
batch_size,
drop_last,
sort_key: Union[Callable, List] = identity,
bucket_size_multiplier=100,
):
super().__init__(sampler, batch_size, drop_last)
self.data = data
self.sort_key = sort_key
_bucket_size = batch_size * bucket_size_multiplier
if hasattr(sampler, "__len__"):
_bucket_size = min(_bucket_size, len(sampler))
self.bucket_sampler = BatchSampler(sampler, _bucket_size, False)
def __iter__(self):
for idxs in self.bucket_sampler:
bucket_data = [self.data[idx] for idx in idxs]
sorted_sampler = SortedSampler(bucket_data, self.sort_key)
for batch_idx in SubsetRandomSampler(list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))):
sorted_idxs = [idxs[i] for i in batch_idx]
yield sorted_idxs
def __len__(self):
if self.drop_last:
return len(self.sampler) // self.batch_size
return math.ceil(len(self.sampler) / self.batch_size)

View File

@ -5,11 +5,11 @@ import unittest
import torch
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.encoder.utils.samplers import PerfectBatchSampler
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.data import get_length_balancer_weights
from TTS.tts.utils.languages import get_language_balancer_weights
from TTS.tts.utils.speakers import get_speaker_balancer_weights
from TTS.utils.samplers import BucketBatchSampler, PerfectBatchSampler
# Fixing random state to avoid random fails
torch.manual_seed(0)
@ -163,3 +163,31 @@ class TestSamplers(unittest.TestCase):
else:
len2 += 1
assert is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced"
def test_bucket_batch_sampler(self):
bucket_size_multiplier = 2
sampler = range(len(train_samples))
sampler = BucketBatchSampler(
sampler,
data=train_samples,
batch_size=7,
drop_last=True,
sort_key=lambda x: len(x["text"]),
bucket_size_multiplier=bucket_size_multiplier,
)
# check if the samples are sorted by text lenght whuile bucketing
min_text_len_in_bucket = 0
bucket_items = []
for batch_idx, batch in enumerate(list(sampler)):
if (batch_idx + 1) % bucket_size_multiplier == 0:
for bucket_item in bucket_items:
self.assertLessEqual(min_text_len_in_bucket, len(train_samples[bucket_item]["text"]))
min_text_len_in_bucket = len(train_samples[bucket_item]["text"])
min_text_len_in_bucket = 0
bucket_items = []
else:
bucket_items += batch
# check sampler length
self.assertEqual(len(sampler), len(train_samples) // 7)