Implement bucketed weighted sampling for VITS (#1871)

Eren Gölge 2022-08-15 11:08:11 +02:00 committed by GitHub
parent d46fbc240c
commit bfc63829ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 359 additions and 28 deletions


@@ -13,13 +13,13 @@ from trainer.trainer_utils import get_optimizer
 from TTS.encoder.dataset import EncoderDataset
 from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
-from TTS.encoder.utils.samplers import PerfectBatchSampler
 from TTS.encoder.utils.training import init_training
 from TTS.encoder.utils.visual import plot_embeddings
 from TTS.tts.datasets import load_tts_samples
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
 from TTS.utils.io import copy_model_files
+from TTS.utils.samplers import PerfectBatchSampler
 from TTS.utils.training import check_update

 torch.backends.cudnn.enabled = True


@@ -70,6 +70,18 @@ class VitsConfig(BaseTTSConfig):
         compute_linear_spec (bool):
             If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.

+        use_weighted_sampler (bool):
+            If true, use a weighted sampler with bucketing for balancing samples between the datasets used in training. Defaults to `False`.
+
+        weighted_sampler_attrs (dict):
+            Key returned by the formatter to be used for the weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
+            by overweighting `root_path` by 2.0. Defaults to `{}`.
+
+        weighted_sampler_multipliers (dict):
+            Weight each unique value of a key returned by the formatter for weighted sampling.
+            For example `{"root_path": {"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/": 1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}}`.
+            It will sample instances from `train-clean-100` twice as often as from `train-clean-360`. Defaults to `{}`.
+
         r (int):
             Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
@@ -124,6 +136,11 @@ class VitsConfig(BaseTTSConfig):
     return_wav: bool = True
     compute_linear_spec: bool = True

+    # sampler params
+    use_weighted_sampler: bool = False  # TODO: move it to the base config
+    weighted_sampler_attrs: dict = field(default_factory=lambda: {})
+    weighted_sampler_multipliers: dict = field(default_factory=lambda: {})
+
     # overrides
     r: int = 1  # DO NOT CHANGE
     add_blank: bool = True
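For context, a minimal sketch of how these new options might be set together in a training config. The dataset paths and the chosen values are illustrative assumptions, not part of this commit.

from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig(
    use_weighted_sampler=True,
    # Balance samples by their originating dataset via the "root_path" key
    # that the formatters now attach to every item; 1.0 is the alpha factor.
    weighted_sampler_attrs={"root_path": 1.0},
    # Within "root_path", give items from corpus_a twice the weight of
    # items from corpus_b (hypothetical dataset paths).
    weighted_sampler_multipliers={
        "root_path": {
            "/data/corpus_a/": 1.0,
            "/data/corpus_b/": 0.5,
        }
    },
)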


@@ -34,6 +34,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
                 "audio_file": audio_path,
                 "speaker_name": speaker_name if speaker_name is not None else row.speaker_name,
                 "emotion_name": emotion_name if emotion_name is not None else row.emotion_name,
+                "root_path": root_path,
             }
         )
     if not_found_counter > 0:

@@ -53,7 +54,7 @@ def tweb(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("\t")
             wav_file = os.path.join(root_path, cols[0] + ".wav")
             text = cols[1]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -68,7 +69,7 @@ def mozilla(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             wav_file = cols[1].strip()
             text = cols[0].strip()
             wav_file = os.path.join(root_path, "wavs", wav_file)
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -84,7 +85,7 @@ def mozilla_de(root_path, meta_file, **kwargs):  # pylint: disable=unused-argume
             text = cols[1].strip()
             folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
             wav_file = os.path.join(root_path, folder_name, wav_file)
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -130,7 +131,9 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
                     wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
                 if os.path.isfile(wav_file):
                     text = cols[1].strip()
-                    items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+                    items.append(
+                        {"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}
+                    )
                 else:
                     # M-AI-Labs have some missing samples, so just print the warning
                     print("> File %s does not exist!" % (wav_file))
@@ -148,7 +151,7 @@ def ljspeech(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -166,7 +169,9 @@ def ljspeech_test(root_path, meta_file, **kwargs):  # pylint: disable=unused-arg
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}"})
+            items.append(
+                {"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}", "root_path": root_path}
+            )
     return items

@@ -181,7 +186,7 @@ def thorsten(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[1]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -198,7 +203,7 @@ def sam_accenture(root_path, meta_file, **kwargs):  # pylint: disable=unused-arg
         if not os.path.exists(wav_file):
             print(f" [!] {wav_file} in metafile does not exist. Skipping...")
             continue
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -213,7 +218,7 @@ def ruslan(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
             text = cols[1]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -261,7 +266,9 @@ def common_voice(root_path, meta_file, ignored_speakers=None):
                 if speaker_name in ignored_speakers:
                     continue
             wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name})
+            items.append(
+                {"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name, "root_path": root_path}
+            )
     return items
@@ -288,7 +295,14 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None):
                 if isinstance(ignored_speakers, list):
                     if speaker_name in ignored_speakers:
                         continue
-                items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"})
+                items.append(
+                    {
+                        "text": text,
+                        "audio_file": wav_file,
+                        "speaker_name": f"LTTS_{speaker_name}",
+                        "root_path": root_path,
+                    }
+                )
     for item in items:
         assert os.path.exists(item["audio_file"]), f" [!] wav files don't exist - {item['audio_file']}"
     return items

@@ -307,7 +321,7 @@ def custom_turkish(root_path, meta_file, **kwargs):  # pylint: disable=unused-ar
                 skipped_files.append(wav_file)
                 continue
             text = cols[1].strip()
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
     return items

@@ -329,7 +343,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
             if isinstance(ignored_speakers, list):
                 if speaker_id in ignored_speakers:
                     continue
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id, "root_path": root_path})
     return items

@@ -372,7 +386,9 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
         else:
             wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}")
         if os.path.exists(wav_file):
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id})
+            items.append(
+                {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path}
+            )
         else:
             print(f" [!] wav files don't exist - {wav_file}")
     return items

@@ -392,7 +408,9 @@ def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=Non
         with open(meta_file, "r", encoding="utf-8") as file_text:
             text = file_text.readlines()[0]
         wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id})
+        items.append(
+            {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id, "root_path": root_path}
+        )
     return items
@@ -411,7 +429,7 @@ def synpaflex(root_path, metafiles=None, **kwargs):  # pylint: disable=unused-ar
         if os.path.exists(txt_file) and os.path.exists(wav_file):
             with open(txt_file, "r", encoding="utf-8") as file_text:
                 text = file_text.readlines()[0]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -433,7 +451,7 @@ def open_bible(root_path, meta_files="train", ignore_digits_sentences=True, igno
             if ignore_digits_sentences and any(map(str.isdigit, text)):
                 continue
             wav_file = os.path.join(root_path, split_dir, speaker_id, file_id + ".flac")
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id, "root_path": root_path})
     return items

@@ -450,7 +468,9 @@ def mls(root_path, meta_files=None, ignored_speakers=None):
             if isinstance(ignored_speakers, list):
                 if speaker in ignored_speakers:
                     continue
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker})
+            items.append(
+                {"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker, "root_path": root_path}
+            )
     return items

@@ -520,7 +540,9 @@ def emotion(root_path, meta_file, ignored_speakers=None):
             if isinstance(ignored_speakers, list):
                 if speaker_id in ignored_speakers:
                     continue
-            items.append({"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id})
+            items.append(
+                {"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id, "root_path": root_path}
+            )
     return items

@@ -540,7 +562,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]:  # pylin
         for line in ttf:
             wav_name, text = line.rstrip("\n").split("|")
             wav_path = os.path.join(root_path, "clips_22", wav_name)
-            items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name, "root_path": root_path})
     return items

@@ -554,5 +576,5 @@ def kokoro(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2].replace(" ", "")
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
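To illustrate the formatter change above: every formatter now attaches the dataset root to each returned item, so downstream code can tell which dataset a sample came from. A toy item might look like the following (all values invented for illustration):

item = {
    "text": "Hello world.",
    "audio_file": "/data/ljspeech/wavs/LJ001-0001.wav",
    "speaker_name": "ljspeech",
    "root_path": "/data/ljspeech/",  # new key consumed by the weighted sampler
}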


@@ -4,6 +4,7 @@ from dataclasses import dataclass, field, replace
 from itertools import chain
 from typing import Dict, List, Tuple, Union

+import numpy as np
 import torch
 import torch.distributed as dist
 import torchaudio

@@ -13,6 +14,8 @@ from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
+from torch.utils.data.sampler import WeightedRandomSampler
+from trainer.torch import DistributedSampler, DistributedSamplerWrapper
 from trainer.trainer_utils import get_optimizer, get_scheduler

 from TTS.tts.configs.shared_configs import CharactersConfig

@@ -29,6 +32,8 @@ from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
+from TTS.utils.io import load_fsspec
+from TTS.utils.samplers import BucketBatchSampler
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results
@@ -221,6 +226,30 @@ class VitsAudioConfig(Coqpit):
 ##############################

+def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
+    """Create inverse frequency weights for balancing the dataset.
+
+    Use `multi_dict` to scale relative weights."""
+    attr_names_samples = np.array([item[attr_name] for item in items])
+    unique_attr_names = np.unique(attr_names_samples).tolist()
+    attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
+    attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
+    weight_attr = 1.0 / attr_count
+    dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
+    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
+    if multi_dict is not None:
+        # check if all keys are in the multi_dict
+        for k in multi_dict:
+            assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
+        # scale weights
+        multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
+        dataset_samples_weight *= multiplier_samples
+    return (
+        torch.from_numpy(dataset_samples_weight).float(),
+        unique_attr_names,
+        np.unique(dataset_samples_weight).tolist(),
+    )
+
+
 class VitsDataset(TTSDataset):
     def __init__(self, model_args, *args, **kwargs):
         super().__init__(*args, **kwargs)
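A small sketch of what the inverse-frequency weighting above produces; the toy items are invented, and the function is assumed to be imported from the module where this diff defines it.

items = [
    {"root_path": "/data/A/"},
    {"root_path": "/data/A/"},
    {"root_path": "/data/A/"},
    {"root_path": "/data/B/"},
]
weights, attr_names, attr_weights = get_attribute_balancer_weights(items, "root_path")
# attr_names -> ['/data/A/', '/data/B/']
# Per-item weights are proportional to 1/3 for the "A" items and 1/1 for the
# single "B" item (then L2-normalized), so the "B" sample is drawn about three
# times as often per item when the weights are fed to WeightedRandomSampler.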
@@ -1510,6 +1539,42 @@ class Vits(BaseTTS):
             batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1)
         return batch

+    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=False):
+        weights = None
+        data_items = dataset.samples
+        if getattr(config, "use_weighted_sampler", False):
+            for attr_name, alpha in config.weighted_sampler_attrs.items():
+                print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'")
+                multi_dict = config.weighted_sampler_multipliers.get(attr_name, None)
+                print(multi_dict)
+                weights, attr_names, attr_weights = get_attribute_balancer_weights(
+                    attr_name=attr_name, items=data_items, multi_dict=multi_dict
+                )
+                weights = weights * alpha
+                print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}")
+
+        # input_audio_lengths = [os.path.getsize(x["audio_file"]) for x in data_items]
+
+        if weights is not None:
+            w_sampler = WeightedRandomSampler(weights, len(weights))
+            batch_sampler = BucketBatchSampler(
+                w_sampler,
+                data=data_items,
+                batch_size=config.eval_batch_size if is_eval else config.batch_size,
+                sort_key=lambda x: os.path.getsize(x["audio_file"]),
+                drop_last=True,
+            )
+        else:
+            batch_sampler = None
+        # sampler for DDP
+        if batch_sampler is None:
+            batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None
+        else:  # If a sampler is already defined use this sampler and DDP sampler together
+            batch_sampler = (
+                DistributedSamplerWrapper(batch_sampler) if num_gpus > 1 else batch_sampler
+            )  # TODO: check batch_sampler with multi-gpu
+        return batch_sampler
+
     def get_data_loader(
         self,
         config: Coqpit,
@@ -1551,10 +1616,7 @@ class Vits(BaseTTS):
             loader = DataLoader(
                 dataset,
-                batch_size=config.eval_batch_size if is_eval else config.batch_size,
-                shuffle=False,  # shuffle is done in the dataset.
-                drop_last=False,  # setting this False might cause issues in AMP training.
-                sampler=sampler,
+                batch_sampler=sampler,
                 collate_fn=dataset.collate_fn,
                 num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                 pin_memory=False,
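The DataLoader change above follows from how PyTorch treats `batch_sampler`: a batch sampler yields whole lists of indices, and `batch_size`, `shuffle`, `sampler`, and `drop_last` must not be passed alongside it. A rough sketch of the resulting wiring; names such as `weights`, `data_items`, `config`, and `dataset` stand in for the values computed in `get_sampler` and `get_data_loader`.

import os

from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

from TTS.utils.samplers import BucketBatchSampler

w_sampler = WeightedRandomSampler(weights, len(weights))  # per-item weights from the balancer
batch_sampler = BucketBatchSampler(
    w_sampler,
    data=data_items,  # formatter items carrying "audio_file" and "root_path"
    batch_size=config.batch_size,
    drop_last=True,
    sort_key=lambda x: os.path.getsize(x["audio_file"]),  # file size as a rough length proxy
)
# The batch sampler yields index lists, so batching and shuffling move out of DataLoader.
loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=dataset.collate_fn)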
@@ -1615,7 +1677,7 @@ class Vits(BaseTTS):
         strict=True,
     ):  # pylint: disable=unused-argument, redefined-builtin
         """Load the model checkpoint and setup for training or inference"""
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         # compat band-aid for the pre-trained models to not use the encoder baked into the model
         # TODO: consider baking the speaker encoder into the model and call it from there.
         # as it is probably easier for model distribution.
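For context, `load_fsspec` wraps `torch.load` so checkpoints can also be read from fsspec-backed locations, not only local paths. A hedged sketch of its use; the URL is a placeholder and assumes the matching fsspec backend (e.g. s3fs) is installed.

from TTS.utils.io import load_fsspec

# Behaves like torch.load for local files, and also accepts fsspec URLs.
state = load_fsspec("s3://my-bucket/checkpoints/vits_best.pth", map_location="cpu")
model.load_state_dict(state["model"])  # assumes the usual {"model": state_dict, ...} layout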

TTS/utils/samplers.py (new file, 202 lines)

@@ -0,0 +1,202 @@
import math
import random
from typing import Callable, List, Union

from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler


class SubsetSampler(Sampler):
    """
    Samples elements sequentially from a given list of indices.

    Args:
        indices (list): a sequence of indices
    """

    def __init__(self, indices):
        super().__init__(indices)
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in range(len(self.indices)))

    def __len__(self):
        return len(self.indices)


class PerfectBatchSampler(Sampler):
    """
    Samples a mini-batch of indices for balanced class batching.

    Args:
        dataset_items (list): dataset items to sample from.
        classes (list): list of classes of dataset_items to sample from.
        batch_size (int): total number of samples to be sampled in a mini-batch.
        num_classes_in_batch (int): number of distinct classes appearing in each mini-batch.
        num_gpus (int): number of GPUs in the data parallel mode.
        shuffle (bool): if True, samples randomly, otherwise samples sequentially.
        drop_last (bool): if True, drops the last incomplete batch.
        label_key (str): key of the class label in a dataset item.
    """

    def __init__(
        self,
        dataset_items,
        classes,
        batch_size,
        num_classes_in_batch,
        num_gpus=1,
        shuffle=True,
        drop_last=False,
        label_key="class_name",
    ):
        super().__init__(dataset_items)
        assert (
            batch_size % (num_classes_in_batch * num_gpus) == 0
        ), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)."

        label_indices = {}
        for idx, item in enumerate(dataset_items):
            label = item[label_key]
            if label not in label_indices.keys():
                label_indices[label] = [idx]
            else:
                label_indices[label].append(idx)

        if shuffle:
            self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes]
        else:
            self._samplers = [SubsetSampler(label_indices[key]) for key in classes]

        self._batch_size = batch_size
        self._drop_last = drop_last
        self._dp_devices = num_gpus
        self._num_classes_in_batch = num_classes_in_batch

    def __iter__(self):
        batch = []
        if self._num_classes_in_batch != len(self._samplers):
            valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
        else:
            valid_samplers_idx = None

        iters = [iter(s) for s in self._samplers]
        done = False

        while True:
            b = []
            for i, it in enumerate(iters):
                if valid_samplers_idx is not None and i not in valid_samplers_idx:
                    continue
                idx = next(it, None)
                if idx is None:
                    done = True
                    break
                b.append(idx)
            if done:
                break
            batch += b
            if len(batch) == self._batch_size:
                yield batch
                batch = []
                if valid_samplers_idx is not None:
                    valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)

        if not self._drop_last:
            if len(batch) > 0:
                groups = len(batch) // self._num_classes_in_batch
                if groups % self._dp_devices == 0:
                    yield batch
                else:
                    batch = batch[: (groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch]
                    if len(batch) > 0:
                        yield batch

    def __len__(self):
        class_batch_size = self._batch_size // self._num_classes_in_batch
        return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers)


def identity(x):
    return x


class SortedSampler(Sampler):
    """Samples elements sequentially, always in the same order.

    Taken from https://github.com/PetrochukM/PyTorch-NLP

    Args:
        data (iterable): Iterable data.
        sort_key (callable): Specifies a function of one argument that is used to extract a
            numerical comparison key from each list element.

    Example:
        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
    """

    def __init__(self, data, sort_key: Callable = identity):
        super().__init__(data)
        self.data = data
        self.sort_key = sort_key
        zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)]
        zip_ = sorted(zip_, key=lambda r: r[1])
        self.sorted_indexes = [item[0] for item in zip_]

    def __iter__(self):
        return iter(self.sorted_indexes)

    def __len__(self):
        return len(self.data)


class BucketBatchSampler(BatchSampler):
    """Bucket batch sampler.

    Adapted from https://github.com/PetrochukM/PyTorch-NLP

    Args:
        sampler (torch.utils.data.sampler.Sampler): Base sampler providing item indices.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If `True` the sampler will drop the last batch if its size would be less
            than `batch_size`.
        data (list): List of data samples.
        sort_key (callable, optional): Callable to specify a comparison key for sorting.
        bucket_size_multiplier (int, optional): Buckets are of size
            `batch_size * bucket_size_multiplier`.

    Example:
        >>> sampler = WeightedRandomSampler(weights, len(weights))
        >>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True)
    """

    def __init__(
        self,
        sampler,
        data,
        batch_size,
        drop_last,
        sort_key: Union[Callable, List] = identity,
        bucket_size_multiplier=100,
    ):
        super().__init__(sampler, batch_size, drop_last)
        self.data = data
        self.sort_key = sort_key
        _bucket_size = batch_size * bucket_size_multiplier
        if hasattr(sampler, "__len__"):
            _bucket_size = min(_bucket_size, len(sampler))
        self.bucket_sampler = BatchSampler(sampler, _bucket_size, False)

    def __iter__(self):
        for idxs in self.bucket_sampler:
            bucket_data = [self.data[idx] for idx in idxs]
            sorted_sampler = SortedSampler(bucket_data, self.sort_key)
            for batch_idx in SubsetRandomSampler(list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))):
                sorted_idxs = [idxs[i] for i in batch_idx]
                yield sorted_idxs

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        return math.ceil(len(self.sampler) / self.batch_size)
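A self-contained sketch of the new `BucketBatchSampler` on toy data; the item dicts and lengths are invented for illustration.

import random

from torch.utils.data.sampler import WeightedRandomSampler

from TTS.utils.samplers import BucketBatchSampler

# Toy items with a fake "length" attribute standing in for audio file size.
data_items = [{"id": i, "length": random.randint(1, 500)} for i in range(100)]
weights = [1.0] * len(data_items)  # uniform weights, just for the demo

base_sampler = WeightedRandomSampler(weights, len(weights))
batch_sampler = BucketBatchSampler(
    base_sampler,
    data=data_items,
    batch_size=8,
    drop_last=True,
    sort_key=lambda x: x["length"],  # items within a bucket are sorted by this key
    bucket_size_multiplier=4,        # each bucket holds 8 * 4 = 32 sampled items
)

for batch in batch_sampler:
    # Each batch is a list of dataset indices whose items have similar lengths,
    # which keeps padding low when the collate function pads to the batch maximum.
    print(batch)
    break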


@@ -5,11 +5,11 @@ import unittest
 import torch

 from TTS.config.shared_configs import BaseDatasetConfig
-from TTS.encoder.utils.samplers import PerfectBatchSampler
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.utils.data import get_length_balancer_weights
 from TTS.tts.utils.languages import get_language_balancer_weights
 from TTS.tts.utils.speakers import get_speaker_balancer_weights
+from TTS.utils.samplers import BucketBatchSampler, PerfectBatchSampler

 # Fixing random state to avoid random fails
 torch.manual_seed(0)
@@ -163,3 +163,31 @@ class TestSamplers(unittest.TestCase):
             else:
                 len2 += 1
         assert is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced"

+    def test_bucket_batch_sampler(self):
+        bucket_size_multiplier = 2
+        sampler = range(len(train_samples))
+        sampler = BucketBatchSampler(
+            sampler,
+            data=train_samples,
+            batch_size=7,
+            drop_last=True,
+            sort_key=lambda x: len(x["text"]),
+            bucket_size_multiplier=bucket_size_multiplier,
+        )
+
+        # check if the samples are sorted by text length while bucketing
+        min_text_len_in_bucket = 0
+        bucket_items = []
+        for batch_idx, batch in enumerate(list(sampler)):
+            if (batch_idx + 1) % bucket_size_multiplier == 0:
+                for bucket_item in bucket_items:
+                    self.assertLessEqual(min_text_len_in_bucket, len(train_samples[bucket_item]["text"]))
+                    min_text_len_in_bucket = len(train_samples[bucket_item]["text"])
+                min_text_len_in_bucket = 0
+                bucket_items = []
+            else:
+                bucket_items += batch
+
+        # check sampler length
+        self.assertEqual(len(sampler), len(train_samples) // 7)