mirror of https://github.com/coqui-ai/TTS.git
Implement bucketed weighted sampling for VITS (#1871)
This commit is contained in:
parent
d46fbc240c
commit
bfc63829ac
|
@ -13,13 +13,13 @@ from trainer.trainer_utils import get_optimizer
|
|||
|
||||
from TTS.encoder.dataset import EncoderDataset
|
||||
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
|
||||
from TTS.encoder.utils.samplers import PerfectBatchSampler
|
||||
from TTS.encoder.utils.training import init_training
|
||||
from TTS.encoder.utils.visual import plot_embeddings
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
|
||||
from TTS.utils.io import copy_model_files
|
||||
from TTS.utils.samplers import PerfectBatchSampler
|
||||
from TTS.utils.training import check_update
|
||||
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
|
|
@ -70,6 +70,18 @@ class VitsConfig(BaseTTSConfig):
|
|||
compute_linear_spec (bool):
|
||||
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
|
||||
|
||||
use_weighted_sampler (bool):
|
||||
If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.
|
||||
|
||||
weighted_sampler_attrs (dict):
|
||||
Key retuned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
|
||||
by overweighting `root_path` by 2.0. Defaults to `{}`.
|
||||
|
||||
weighted_sampler_multipliers (dict):
|
||||
Weight each unique value of a key returned by the formatter for weighted sampling.
|
||||
For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`.
|
||||
It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.
|
||||
|
||||
r (int):
|
||||
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
|
||||
|
||||
|
@ -124,6 +136,11 @@ class VitsConfig(BaseTTSConfig):
|
|||
return_wav: bool = True
|
||||
compute_linear_spec: bool = True
|
||||
|
||||
# sampler params
|
||||
use_weighted_sampler: bool = False # TODO: move it to the base config
|
||||
weighted_sampler_attrs: dict = field(default_factory=lambda: {})
|
||||
weighted_sampler_multipliers: dict = field(default_factory=lambda: {})
|
||||
|
||||
# overrides
|
||||
r: int = 1 # DO NOT CHANGE
|
||||
add_blank: bool = True
|
||||
|
|
|
@ -34,6 +34,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
|
|||
"audio_file": audio_path,
|
||||
"speaker_name": speaker_name if speaker_name is not None else row.speaker_name,
|
||||
"emotion_name": emotion_name if emotion_name is not None else row.emotion_name,
|
||||
"root_path": root_path,
|
||||
}
|
||||
)
|
||||
if not_found_counter > 0:
|
||||
|
@ -53,7 +54,7 @@ def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
|
|||
cols = line.split("\t")
|
||||
wav_file = os.path.join(root_path, cols[0] + ".wav")
|
||||
text = cols[1]
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -68,7 +69,7 @@ def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
|
|||
wav_file = cols[1].strip()
|
||||
text = cols[0].strip()
|
||||
wav_file = os.path.join(root_path, "wavs", wav_file)
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -84,7 +85,7 @@ def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argume
|
|||
text = cols[1].strip()
|
||||
folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
|
||||
wav_file = os.path.join(root_path, folder_name, wav_file)
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -130,7 +131,9 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
|
|||
wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
|
||||
if os.path.isfile(wav_file):
|
||||
text = cols[1].strip()
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append(
|
||||
{"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}
|
||||
)
|
||||
else:
|
||||
# M-AI-Labs have some missing samples, so just print the warning
|
||||
print("> File %s does not exist!" % (wav_file))
|
||||
|
@ -148,7 +151,7 @@ def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
|
|||
cols = line.split("|")
|
||||
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
|
||||
text = cols[2]
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -166,7 +169,9 @@ def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
|
|||
cols = line.split("|")
|
||||
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
|
||||
text = cols[2]
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}"})
|
||||
items.append(
|
||||
{"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}", "root_path": root_path}
|
||||
)
|
||||
return items
|
||||
|
||||
|
||||
|
@ -181,7 +186,7 @@ def thorsten(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
|
|||
cols = line.split("|")
|
||||
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
|
||||
text = cols[1]
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -198,7 +203,7 @@ def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
|
|||
if not os.path.exists(wav_file):
|
||||
print(f" [!] {wav_file} in metafile does not exist. Skipping...")
|
||||
continue
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -213,7 +218,7 @@ def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
|
|||
cols = line.split("|")
|
||||
wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
|
||||
text = cols[1]
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -261,7 +266,9 @@ def common_voice(root_path, meta_file, ignored_speakers=None):
|
|||
if speaker_name in ignored_speakers:
|
||||
continue
|
||||
wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name})
|
||||
items.append(
|
||||
{"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name, "root_path": root_path}
|
||||
)
|
||||
return items
|
||||
|
||||
|
||||
|
@ -288,7 +295,14 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None):
|
|||
if isinstance(ignored_speakers, list):
|
||||
if speaker_name in ignored_speakers:
|
||||
continue
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"})
|
||||
items.append(
|
||||
{
|
||||
"text": text,
|
||||
"audio_file": wav_file,
|
||||
"speaker_name": f"LTTS_{speaker_name}",
|
||||
"root_path": root_path,
|
||||
}
|
||||
)
|
||||
for item in items:
|
||||
assert os.path.exists(item["audio_file"]), f" [!] wav files don't exist - {item['audio_file']}"
|
||||
return items
|
||||
|
@ -307,7 +321,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar
|
|||
skipped_files.append(wav_file)
|
||||
continue
|
||||
text = cols[1].strip()
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
|
||||
return items
|
||||
|
||||
|
@ -329,7 +343,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
|
|||
if isinstance(ignored_speakers, list):
|
||||
if speaker_id in ignored_speakers:
|
||||
continue
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -372,7 +386,9 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic
|
|||
else:
|
||||
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}")
|
||||
if os.path.exists(wav_file):
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id})
|
||||
items.append(
|
||||
{"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path}
|
||||
)
|
||||
else:
|
||||
print(f" [!] wav files don't exist - {wav_file}")
|
||||
return items
|
||||
|
@ -392,7 +408,9 @@ def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=Non
|
|||
with open(meta_file, "r", encoding="utf-8") as file_text:
|
||||
text = file_text.readlines()[0]
|
||||
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id})
|
||||
items.append(
|
||||
{"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id, "root_path": root_path}
|
||||
)
|
||||
return items
|
||||
|
||||
|
||||
|
@ -411,7 +429,7 @@ def synpaflex(root_path, metafiles=None, **kwargs): # pylint: disable=unused-ar
|
|||
if os.path.exists(txt_file) and os.path.exists(wav_file):
|
||||
with open(txt_file, "r", encoding="utf-8") as file_text:
|
||||
text = file_text.readlines()[0]
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -433,7 +451,7 @@ def open_bible(root_path, meta_files="train", ignore_digits_sentences=True, igno
|
|||
if ignore_digits_sentences and any(map(str.isdigit, text)):
|
||||
continue
|
||||
wav_file = os.path.join(root_path, split_dir, speaker_id, file_id + ".flac")
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -450,7 +468,9 @@ def mls(root_path, meta_files=None, ignored_speakers=None):
|
|||
if isinstance(ignored_speakers, list):
|
||||
if speaker in ignored_speakers:
|
||||
continue
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker})
|
||||
items.append(
|
||||
{"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker, "root_path": root_path}
|
||||
)
|
||||
return items
|
||||
|
||||
|
||||
|
@ -520,7 +540,9 @@ def emotion(root_path, meta_file, ignored_speakers=None):
|
|||
if isinstance(ignored_speakers, list):
|
||||
if speaker_id in ignored_speakers:
|
||||
continue
|
||||
items.append({"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id})
|
||||
items.append(
|
||||
{"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id, "root_path": root_path}
|
||||
)
|
||||
return items
|
||||
|
||||
|
||||
|
@ -540,7 +562,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylin
|
|||
for line in ttf:
|
||||
wav_name, text = line.rstrip("\n").split("|")
|
||||
wav_path = os.path.join(root_path, "clips_22", wav_name)
|
||||
items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
||||
|
||||
|
@ -554,5 +576,5 @@ def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
|
|||
cols = line.split("|")
|
||||
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
|
||||
text = cols[2].replace(" ", "")
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
|
||||
items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
|
||||
return items
|
||||
|
|
|
@ -4,6 +4,7 @@ from dataclasses import dataclass, field, replace
|
|||
from itertools import chain
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torchaudio
|
||||
|
@ -13,6 +14,8 @@ from torch import nn
|
|||
from torch.cuda.amp.autocast_mode import autocast
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.sampler import WeightedRandomSampler
|
||||
from trainer.torch import DistributedSampler, DistributedSamplerWrapper
|
||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||
|
||||
from TTS.tts.configs.shared_configs import CharactersConfig
|
||||
|
@ -29,6 +32,8 @@ from TTS.tts.utils.synthesis import synthesis
|
|||
from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.utils.samplers import BucketBatchSampler
|
||||
from TTS.vocoder.models.hifigan_generator import HifiganGenerator
|
||||
from TTS.vocoder.utils.generic_utils import plot_results
|
||||
|
||||
|
@ -221,6 +226,30 @@ class VitsAudioConfig(Coqpit):
|
|||
##############################
|
||||
|
||||
|
||||
def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
|
||||
"""Create inverse frequency weights for balancing the dataset.
|
||||
Use `multi_dict` to scale relative weights."""
|
||||
attr_names_samples = np.array([item[attr_name] for item in items])
|
||||
unique_attr_names = np.unique(attr_names_samples).tolist()
|
||||
attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
|
||||
attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
|
||||
weight_attr = 1.0 / attr_count
|
||||
dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
|
||||
dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
|
||||
if multi_dict is not None:
|
||||
# check if all keys are in the multi_dict
|
||||
for k in multi_dict:
|
||||
assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
|
||||
# scale weights
|
||||
multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
|
||||
dataset_samples_weight *= multiplier_samples
|
||||
return (
|
||||
torch.from_numpy(dataset_samples_weight).float(),
|
||||
unique_attr_names,
|
||||
np.unique(dataset_samples_weight).tolist(),
|
||||
)
|
||||
|
||||
|
||||
class VitsDataset(TTSDataset):
|
||||
def __init__(self, model_args, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
@ -1510,6 +1539,42 @@ class Vits(BaseTTS):
|
|||
batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1)
|
||||
return batch
|
||||
|
||||
def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=False):
|
||||
weights = None
|
||||
data_items = dataset.samples
|
||||
if getattr(config, "use_weighted_sampler", False):
|
||||
for attr_name, alpha in config.weighted_sampler_attrs.items():
|
||||
print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'")
|
||||
multi_dict = config.weighted_sampler_multipliers.get(attr_name, None)
|
||||
print(multi_dict)
|
||||
weights, attr_names, attr_weights = get_attribute_balancer_weights(
|
||||
attr_name=attr_name, items=data_items, multi_dict=multi_dict
|
||||
)
|
||||
weights = weights * alpha
|
||||
print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}")
|
||||
|
||||
# input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items]
|
||||
|
||||
if weights is not None:
|
||||
w_sampler = WeightedRandomSampler(weights, len(weights))
|
||||
batch_sampler = BucketBatchSampler(
|
||||
w_sampler,
|
||||
data=data_items,
|
||||
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
||||
sort_key=lambda x: os.path.getsize(x["audio_file"]),
|
||||
drop_last=True,
|
||||
)
|
||||
else:
|
||||
batch_sampler = None
|
||||
# sampler for DDP
|
||||
if batch_sampler is None:
|
||||
batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
else: # If a sampler is already defined use this sampler and DDP sampler together
|
||||
batch_sampler = (
|
||||
DistributedSamplerWrapper(batch_sampler) if num_gpus > 1 else batch_sampler
|
||||
) # TODO: check batch_sampler with multi-gpu
|
||||
return batch_sampler
|
||||
|
||||
def get_data_loader(
|
||||
self,
|
||||
config: Coqpit,
|
||||
|
@ -1551,10 +1616,7 @@ class Vits(BaseTTS):
|
|||
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
||||
shuffle=False, # shuffle is done in the dataset.
|
||||
drop_last=False, # setting this False might cause issues in AMP training.
|
||||
sampler=sampler,
|
||||
batch_sampler=sampler,
|
||||
collate_fn=dataset.collate_fn,
|
||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
|
@ -1615,7 +1677,7 @@ class Vits(BaseTTS):
|
|||
strict=True,
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
"""Load the model checkpoint and setup for training or inference"""
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
|
||||
# compat band-aid for the pre-trained models to not use the encoder baked into the model
|
||||
# TODO: consider baking the speaker encoder into the model and call it from there.
|
||||
# as it is probably easier for model distribution.
|
||||
|
|
|
@ -0,0 +1,202 @@
|
|||
import math
|
||||
import random
|
||||
from typing import Callable, List, Union
|
||||
|
||||
from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler
|
||||
|
||||
|
||||
class SubsetSampler(Sampler):
|
||||
"""
|
||||
Samples elements sequentially from a given list of indices.
|
||||
|
||||
Args:
|
||||
indices (list): a sequence of indices
|
||||
"""
|
||||
|
||||
def __init__(self, indices):
|
||||
super().__init__(indices)
|
||||
self.indices = indices
|
||||
|
||||
def __iter__(self):
|
||||
return (self.indices[i] for i in range(len(self.indices)))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.indices)
|
||||
|
||||
|
||||
class PerfectBatchSampler(Sampler):
|
||||
"""
|
||||
Samples a mini-batch of indices for a balanced class batching
|
||||
|
||||
Args:
|
||||
dataset_items(list): dataset items to sample from.
|
||||
classes (list): list of classes of dataset_items to sample from.
|
||||
batch_size (int): total number of samples to be sampled in a mini-batch.
|
||||
num_gpus (int): number of GPU in the data parallel mode.
|
||||
shuffle (bool): if True, samples randomly, otherwise samples sequentially.
|
||||
drop_last (bool): if True, drops last incomplete batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset_items,
|
||||
classes,
|
||||
batch_size,
|
||||
num_classes_in_batch,
|
||||
num_gpus=1,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
label_key="class_name",
|
||||
):
|
||||
super().__init__(dataset_items)
|
||||
assert (
|
||||
batch_size % (num_classes_in_batch * num_gpus) == 0
|
||||
), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)."
|
||||
|
||||
label_indices = {}
|
||||
for idx, item in enumerate(dataset_items):
|
||||
label = item[label_key]
|
||||
if label not in label_indices.keys():
|
||||
label_indices[label] = [idx]
|
||||
else:
|
||||
label_indices[label].append(idx)
|
||||
|
||||
if shuffle:
|
||||
self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes]
|
||||
else:
|
||||
self._samplers = [SubsetSampler(label_indices[key]) for key in classes]
|
||||
|
||||
self._batch_size = batch_size
|
||||
self._drop_last = drop_last
|
||||
self._dp_devices = num_gpus
|
||||
self._num_classes_in_batch = num_classes_in_batch
|
||||
|
||||
def __iter__(self):
|
||||
|
||||
batch = []
|
||||
if self._num_classes_in_batch != len(self._samplers):
|
||||
valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
|
||||
else:
|
||||
valid_samplers_idx = None
|
||||
|
||||
iters = [iter(s) for s in self._samplers]
|
||||
done = False
|
||||
|
||||
while True:
|
||||
b = []
|
||||
for i, it in enumerate(iters):
|
||||
if valid_samplers_idx is not None and i not in valid_samplers_idx:
|
||||
continue
|
||||
idx = next(it, None)
|
||||
if idx is None:
|
||||
done = True
|
||||
break
|
||||
b.append(idx)
|
||||
if done:
|
||||
break
|
||||
batch += b
|
||||
if len(batch) == self._batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
if valid_samplers_idx is not None:
|
||||
valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
|
||||
|
||||
if not self._drop_last:
|
||||
if len(batch) > 0:
|
||||
groups = len(batch) // self._num_classes_in_batch
|
||||
if groups % self._dp_devices == 0:
|
||||
yield batch
|
||||
else:
|
||||
batch = batch[: (groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch]
|
||||
if len(batch) > 0:
|
||||
yield batch
|
||||
|
||||
def __len__(self):
|
||||
class_batch_size = self._batch_size // self._num_classes_in_batch
|
||||
return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers)
|
||||
|
||||
|
||||
def identity(x):
|
||||
return x
|
||||
|
||||
|
||||
class SortedSampler(Sampler):
|
||||
"""Samples elements sequentially, always in the same order.
|
||||
|
||||
Taken from https://github.com/PetrochukM/PyTorch-NLP
|
||||
|
||||
Args:
|
||||
data (iterable): Iterable data.
|
||||
sort_key (callable): Specifies a function of one argument that is used to extract a
|
||||
numerical comparison key from each list element.
|
||||
|
||||
Example:
|
||||
>>> list(SortedSampler(range(10), sort_key=lambda i: -i))
|
||||
[9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, data, sort_key: Callable = identity):
|
||||
super().__init__(data)
|
||||
self.data = data
|
||||
self.sort_key = sort_key
|
||||
zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)]
|
||||
zip_ = sorted(zip_, key=lambda r: r[1])
|
||||
self.sorted_indexes = [item[0] for item in zip_]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.sorted_indexes)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
class BucketBatchSampler(BatchSampler):
|
||||
"""Bucket batch sampler
|
||||
|
||||
Adapted from https://github.com/PetrochukM/PyTorch-NLP
|
||||
|
||||
Args:
|
||||
sampler (torch.data.utils.sampler.Sampler):
|
||||
batch_size (int): Size of mini-batch.
|
||||
drop_last (bool): If `True` the sampler will drop the last batch if its size would be less
|
||||
than `batch_size`.
|
||||
data (list): List of data samples.
|
||||
sort_key (callable, optional): Callable to specify a comparison key for sorting.
|
||||
bucket_size_multiplier (int, optional): Buckets are of size
|
||||
`batch_size * bucket_size_multiplier`.
|
||||
|
||||
Example:
|
||||
>>> sampler = WeightedRandomSampler(weights, len(weights))
|
||||
>>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sampler,
|
||||
data,
|
||||
batch_size,
|
||||
drop_last,
|
||||
sort_key: Union[Callable, List] = identity,
|
||||
bucket_size_multiplier=100,
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
self.data = data
|
||||
self.sort_key = sort_key
|
||||
_bucket_size = batch_size * bucket_size_multiplier
|
||||
if hasattr(sampler, "__len__"):
|
||||
_bucket_size = min(_bucket_size, len(sampler))
|
||||
self.bucket_sampler = BatchSampler(sampler, _bucket_size, False)
|
||||
|
||||
def __iter__(self):
|
||||
for idxs in self.bucket_sampler:
|
||||
bucket_data = [self.data[idx] for idx in idxs]
|
||||
sorted_sampler = SortedSampler(bucket_data, self.sort_key)
|
||||
for batch_idx in SubsetRandomSampler(list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))):
|
||||
sorted_idxs = [idxs[i] for i in batch_idx]
|
||||
yield sorted_idxs
|
||||
|
||||
def __len__(self):
|
||||
if self.drop_last:
|
||||
return len(self.sampler) // self.batch_size
|
||||
return math.ceil(len(self.sampler) / self.batch_size)
|
|
@ -5,11 +5,11 @@ import unittest
|
|||
import torch
|
||||
|
||||
from TTS.config.shared_configs import BaseDatasetConfig
|
||||
from TTS.encoder.utils.samplers import PerfectBatchSampler
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.tts.utils.data import get_length_balancer_weights
|
||||
from TTS.tts.utils.languages import get_language_balancer_weights
|
||||
from TTS.tts.utils.speakers import get_speaker_balancer_weights
|
||||
from TTS.utils.samplers import BucketBatchSampler, PerfectBatchSampler
|
||||
|
||||
# Fixing random state to avoid random fails
|
||||
torch.manual_seed(0)
|
||||
|
@ -163,3 +163,31 @@ class TestSamplers(unittest.TestCase):
|
|||
else:
|
||||
len2 += 1
|
||||
assert is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced"
|
||||
|
||||
def test_bucket_batch_sampler(self):
|
||||
bucket_size_multiplier = 2
|
||||
sampler = range(len(train_samples))
|
||||
sampler = BucketBatchSampler(
|
||||
sampler,
|
||||
data=train_samples,
|
||||
batch_size=7,
|
||||
drop_last=True,
|
||||
sort_key=lambda x: len(x["text"]),
|
||||
bucket_size_multiplier=bucket_size_multiplier,
|
||||
)
|
||||
|
||||
# check if the samples are sorted by text lenght whuile bucketing
|
||||
min_text_len_in_bucket = 0
|
||||
bucket_items = []
|
||||
for batch_idx, batch in enumerate(list(sampler)):
|
||||
if (batch_idx + 1) % bucket_size_multiplier == 0:
|
||||
for bucket_item in bucket_items:
|
||||
self.assertLessEqual(min_text_len_in_bucket, len(train_samples[bucket_item]["text"]))
|
||||
min_text_len_in_bucket = len(train_samples[bucket_item]["text"])
|
||||
min_text_len_in_bucket = 0
|
||||
bucket_items = []
|
||||
else:
|
||||
bucket_items += batch
|
||||
|
||||
# check sampler length
|
||||
self.assertEqual(len(sampler), len(train_samples) // 7)
|
||||
|
|
Loading…
Reference in New Issue