mirror of https://github.com/coqui-ai/TTS.git

Implement bucketed weighted sampling for VITS (#1871)

parent d46fbc240c
commit bfc63829ac
@@ -13,13 +13,13 @@ from trainer.trainer_utils import get_optimizer
 from TTS.encoder.dataset import EncoderDataset
 from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
-from TTS.encoder.utils.samplers import PerfectBatchSampler
 from TTS.encoder.utils.training import init_training
 from TTS.encoder.utils.visual import plot_embeddings
 from TTS.tts.datasets import load_tts_samples
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
 from TTS.utils.io import copy_model_files
+from TTS.utils.samplers import PerfectBatchSampler
 from TTS.utils.training import check_update

 torch.backends.cudnn.enabled = True
@@ -70,6 +70,18 @@ class VitsConfig(BaseTTSConfig):
         compute_linear_spec (bool):
             If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.

+        use_weighted_sampler (bool):
+            If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.
+
+        weighted_sampler_attrs (dict):
+            Key returned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
+            by overweighting `root_path` by 2.0. Defaults to `{}`.
+
+        weighted_sampler_multipliers (dict):
+            Weight each unique value of a key returned by the formatter for weighted sampling.
+            For example `{"root_path": {"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/": 1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}}`.
+            It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.
+
         r (int):
             Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
@@ -124,6 +136,11 @@ class VitsConfig(BaseTTSConfig):
     return_wav: bool = True
     compute_linear_spec: bool = True

+    # sampler params
+    use_weighted_sampler: bool = False  # TODO: move it to the base config
+    weighted_sampler_attrs: dict = field(default_factory=lambda: {})
+    weighted_sampler_multipliers: dict = field(default_factory=lambda: {})
+
     # overrides
     r: int = 1  # DO NOT CHANGE
     add_blank: bool = True
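As a rough illustration of the new fields, below is a minimal sketch of a training config that turns on bucketed weighted sampling; the dataset paths and weights are hypothetical, only the field names come from the hunk above.

# Hypothetical sketch: enabling bucketed weighted sampling in a VITS config.
from TTS.tts.configs.vits_config import VitsConfig

config = VitsConfig(
    use_weighted_sampler=True,
    # balance samples across datasets, keyed by the "root_path" returned by the formatters
    weighted_sampler_attrs={"root_path": 1.0},
    # optionally favour one dataset over another (values scale the inverse-frequency weights)
    weighted_sampler_multipliers={
        "root_path": {
            "/data/datasetA/": 1.0,  # hypothetical path
            "/data/datasetB/": 0.5,  # hypothetical path, sampled half as often
        }
    },
)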
@@ -34,6 +34,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
                 "audio_file": audio_path,
                 "speaker_name": speaker_name if speaker_name is not None else row.speaker_name,
                 "emotion_name": emotion_name if emotion_name is not None else row.emotion_name,
+                "root_path": root_path,
             }
         )
     if not_found_counter > 0:
@@ -53,7 +54,7 @@ def tweb(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("\t")
             wav_file = os.path.join(root_path, cols[0] + ".wav")
             text = cols[1]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items


@@ -68,7 +69,7 @@ def mozilla(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             wav_file = cols[1].strip()
             text = cols[0].strip()
             wav_file = os.path.join(root_path, "wavs", wav_file)
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items


@@ -84,7 +85,7 @@ def mozilla_de(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             text = cols[1].strip()
             folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
             wav_file = os.path.join(root_path, folder_name, wav_file)
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items


@@ -130,7 +131,9 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
                         wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
                     if os.path.isfile(wav_file):
                         text = cols[1].strip()
-                        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+                        items.append(
+                            {"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path}
+                        )
                     else:
                         # M-AI-Labs have some missing samples, so just print the warning
                         print("> File %s does not exist!" % (wav_file))
@@ -148,7 +151,7 @@ def ljspeech(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items


@@ -166,7 +169,9 @@ def ljspeech_test(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}"})
+            items.append(
+                {"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{speaker_id}", "root_path": root_path}
+            )
     return items


@@ -181,7 +186,7 @@ def thorsten(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[1]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items


@@ -198,7 +203,7 @@ def sam_accenture(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
         if not os.path.exists(wav_file):
             print(f" [!] {wav_file} in metafile does not exist. Skipping...")
             continue
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items


@@ -213,7 +218,7 @@ def ruslan(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
             text = cols[1]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items


@@ -261,7 +266,9 @@ def common_voice(root_path, meta_file, ignored_speakers=None):
                 if speaker_name in ignored_speakers:
                     continue
             wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name})
+            items.append(
+                {"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name, "root_path": root_path}
+            )
     return items


@@ -288,7 +295,14 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None):
                 if isinstance(ignored_speakers, list):
                     if speaker_name in ignored_speakers:
                         continue
-                items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"})
+                items.append(
+                    {
+                        "text": text,
+                        "audio_file": wav_file,
+                        "speaker_name": f"LTTS_{speaker_name}",
+                        "root_path": root_path,
+                    }
+                )
     for item in items:
         assert os.path.exists(item["audio_file"]), f" [!] wav files don't exist - {item['audio_file']}"
     return items
@@ -307,7 +321,7 @@ def custom_turkish(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             skipped_files.append(wav_file)
             continue
         text = cols[1].strip()
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
     return items

@@ -329,7 +343,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
             if isinstance(ignored_speakers, list):
                 if speaker_id in ignored_speakers:
                     continue
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id, "root_path": root_path})
     return items


@@ -372,7 +386,9 @@ def vctk(root_path, meta_files=None, wavs_path="wav48_silence_trimmed", mic="mic1", ignored_speakers=None):
         else:
             wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + f"_{mic}.{file_ext}")
         if os.path.exists(wav_file):
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id})
+            items.append(
+                {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id, "root_path": root_path}
+            )
         else:
             print(f" [!] wav files don't exist - {wav_file}")
     return items
@@ -392,7 +408,9 @@ def vctk_old(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
         with open(meta_file, "r", encoding="utf-8") as file_text:
             text = file_text.readlines()[0]
         wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
-        items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id})
+        items.append(
+            {"text": text, "audio_file": wav_file, "speaker_name": "VCTK_old_" + speaker_id, "root_path": root_path}
+        )
     return items


@@ -411,7 +429,7 @@ def synpaflex(root_path, metafiles=None, **kwargs):  # pylint: disable=unused-argument
         if os.path.exists(txt_file) and os.path.exists(wav_file):
             with open(txt_file, "r", encoding="utf-8") as file_text:
                 text = file_text.readlines()[0]
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items


@@ -433,7 +451,7 @@ def open_bible(root_path, meta_files="train", ignore_digits_sentences=True, ignored_speakers=None):
             if ignore_digits_sentences and any(map(str.isdigit, text)):
                 continue
             wav_file = os.path.join(root_path, split_dir, speaker_id, file_id + ".flac")
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": "OB_" + speaker_id, "root_path": root_path})
     return items


@@ -450,7 +468,9 @@ def mls(root_path, meta_files=None, ignored_speakers=None):
             if isinstance(ignored_speakers, list):
                 if speaker in ignored_speakers:
                     continue
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker})
+            items.append(
+                {"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker, "root_path": root_path}
+            )
     return items


@@ -520,7 +540,9 @@ def emotion(root_path, meta_file, ignored_speakers=None):
             if isinstance(ignored_speakers, list):
                 if speaker_id in ignored_speakers:
                     continue
-            items.append({"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id})
+            items.append(
+                {"audio_file": wav_file, "speaker_name": speaker_id, "emotion_name": emotion_id, "root_path": root_path}
+            )
     return items


@@ -540,7 +562,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]:  # pylint: disable=unused-argument
         for line in ttf:
             wav_name, text = line.rstrip("\n").split("|")
             wav_path = os.path.join(root_path, "clips_22", wav_name)
-            items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name, "root_path": root_path})
     return items


@@ -554,5 +576,5 @@ def kokoro(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2].replace(" ", "")
-            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
     return items
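With the formatter changes above, every sample item now carries the dataset's `root_path` next to the existing keys; a representative item (values hypothetical) looks like the sketch below, and `root_path` is the key the weighted sampler can balance on.

# Hypothetical item as produced by a formatter after this change.
item = {
    "text": "Hello world.",
    "audio_file": "/data/LJSpeech-1.1/wavs/LJ001-0001.wav",  # hypothetical path
    "speaker_name": "ljspeech",
    "root_path": "/data/LJSpeech-1.1/",  # new key, used for dataset-level balancing
}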
@@ -4,6 +4,7 @@ from dataclasses import dataclass, field, replace
 from itertools import chain
 from typing import Dict, List, Tuple, Union

+import numpy as np
 import torch
 import torch.distributed as dist
 import torchaudio
@@ -13,6 +14,8 @@ from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
+from torch.utils.data.sampler import WeightedRandomSampler
+from trainer.torch import DistributedSampler, DistributedSamplerWrapper
 from trainer.trainer_utils import get_optimizer, get_scheduler

 from TTS.tts.configs.shared_configs import CharactersConfig
@@ -29,6 +32,8 @@ from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
+from TTS.utils.io import load_fsspec
+from TTS.utils.samplers import BucketBatchSampler
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results

@@ -221,6 +226,30 @@ class VitsAudioConfig(Coqpit):
 ##############################


+def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None):
+    """Create inverse frequency weights for balancing the dataset.
+    Use `multi_dict` to scale relative weights."""
+    attr_names_samples = np.array([item[attr_name] for item in items])
+    unique_attr_names = np.unique(attr_names_samples).tolist()
+    attr_idx = [unique_attr_names.index(l) for l in attr_names_samples]
+    attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names])
+    weight_attr = 1.0 / attr_count
+    dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx])
+    dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight)
+    if multi_dict is not None:
+        # check if all keys are in the multi_dict
+        for k in multi_dict:
+            assert k in unique_attr_names, f"{k} not in {unique_attr_names}"
+        # scale weights
+        multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items])
+        dataset_samples_weight *= multiplier_samples
+    return (
+        torch.from_numpy(dataset_samples_weight).float(),
+        unique_attr_names,
+        np.unique(dataset_samples_weight).tolist(),
+    )
+
+
 class VitsDataset(TTSDataset):
     def __init__(self, model_args, *args, **kwargs):
         super().__init__(*args, **kwargs)
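A small sketch of how `get_attribute_balancer_weights` behaves; the three-item dataset is made up, and the function is assumed importable from `TTS.tts.models.vits`, where the hunk above defines it.

# Toy example for get_attribute_balancer_weights (defined in the hunk above).
import torch
from TTS.tts.models.vits import get_attribute_balancer_weights

items = [
    {"root_path": "/data/A/"},  # appears twice -> smaller per-sample weight
    {"root_path": "/data/A/"},
    {"root_path": "/data/B/"},  # appears once -> larger per-sample weight
]
weights, attr_names, unique_weights = get_attribute_balancer_weights(items, "root_path")
# weights has one entry per item; rarer attribute values get larger weights, so a
# WeightedRandomSampler built from it draws the two datasets about equally often.
assert isinstance(weights, torch.Tensor) and len(weights) == len(items)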
@@ -1510,6 +1539,42 @@ class Vits(BaseTTS):
         batch["mel"] = batch["mel"] * sequence_mask(batch["mel_lens"]).unsqueeze(1)
         return batch

+    def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=False):
+        weights = None
+        data_items = dataset.samples
+        if getattr(config, "use_weighted_sampler", False):
+            for attr_name, alpha in config.weighted_sampler_attrs.items():
+                print(f" > Using weighted sampler for attribute '{attr_name}' with alpha '{alpha}'")
+                multi_dict = config.weighted_sampler_multipliers.get(attr_name, None)
+                print(multi_dict)
+                weights, attr_names, attr_weights = get_attribute_balancer_weights(
+                    attr_name=attr_name, items=data_items, multi_dict=multi_dict
+                )
+                weights = weights * alpha
+                print(f" > Attribute weights for '{attr_names}' \n | > {attr_weights}")
+
+        # input_audio_lenghts = [os.path.getsize(x["audio_file"]) for x in data_items]
+
+        if weights is not None:
+            w_sampler = WeightedRandomSampler(weights, len(weights))
+            batch_sampler = BucketBatchSampler(
+                w_sampler,
+                data=data_items,
+                batch_size=config.eval_batch_size if is_eval else config.batch_size,
+                sort_key=lambda x: os.path.getsize(x["audio_file"]),
+                drop_last=True,
+            )
+        else:
+            batch_sampler = None
+        # sampler for DDP
+        if batch_sampler is None:
+            batch_sampler = DistributedSampler(dataset) if num_gpus > 1 else None
+        else:  # If a sampler is already defined use this sampler and DDP sampler together
+            batch_sampler = (
+                DistributedSamplerWrapper(batch_sampler) if num_gpus > 1 else batch_sampler
+            )  # TODO: check batch_sampler with multi-gpu
+        return batch_sampler
+
     def get_data_loader(
         self,
         config: Coqpit,
@@ -1551,10 +1616,7 @@

             loader = DataLoader(
                 dataset,
-                batch_size=config.eval_batch_size if is_eval else config.batch_size,
-                shuffle=False,  # shuffle is done in the dataset.
-                drop_last=False,  # setting this False might cause issues in AMP training.
-                sampler=sampler,
+                batch_sampler=sampler,
                 collate_fn=dataset.collate_fn,
                 num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
                 pin_memory=False,
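The DataLoader change above follows from how PyTorch treats `batch_sampler`: it is mutually exclusive with `batch_size`, `shuffle`, `sampler` and `drop_last`, so those arguments are dropped once `get_sampler` can return a batch sampler. A minimal sketch of the same pattern, with a made-up tensor dataset:

# Minimal sketch of the batch_sampler pattern used above; the dataset is made up.
import torch
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(100).unsqueeze(1))
batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=8, drop_last=True)

# With batch_sampler set, DataLoader must not also receive batch_size/shuffle/sampler/drop_last.
loader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0, pin_memory=False)
for batch in loader:
    pass  # each batch collates 8 rows from the dataset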
@@ -1615,7 +1677,7 @@
         strict=True,
     ):  # pylint: disable=unused-argument, redefined-builtin
         """Load the model checkpoint and setup for training or inference"""
-        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
         # compat band-aid for the pre-trained models to not use the encoder baked into the model
         # TODO: consider baking the speaker encoder into the model and call it from there.
         # as it is probably easier for model distribution.
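The swap to `load_fsspec` (imported from `TTS.utils.io` in an earlier hunk) routes checkpoint reading through fsspec, so the path can be anything fsspec can open, local or remote. A hedged sketch of the call; the checkpoint path is hypothetical.

# Hypothetical sketch: load_fsspec accepts local paths as well as fsspec-style URLs.
import torch
from TTS.utils.io import load_fsspec

state = load_fsspec("path/or/url/to/checkpoint.pth", map_location=torch.device("cpu"))  # hypothetical path
# `state` is the same dict torch.load would return for a local file.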
@@ -0,0 +1,202 @@
+import math
+import random
+from typing import Callable, List, Union
+
+from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler
+
+
+class SubsetSampler(Sampler):
+    """
+    Samples elements sequentially from a given list of indices.
+
+    Args:
+        indices (list): a sequence of indices
+    """
+
+    def __init__(self, indices):
+        super().__init__(indices)
+        self.indices = indices
+
+    def __iter__(self):
+        return (self.indices[i] for i in range(len(self.indices)))
+
+    def __len__(self):
+        return len(self.indices)
+
+
+class PerfectBatchSampler(Sampler):
+    """
+    Samples a mini-batch of indices for a balanced class batching
+
+    Args:
+        dataset_items(list): dataset items to sample from.
+        classes (list): list of classes of dataset_items to sample from.
+        batch_size (int): total number of samples to be sampled in a mini-batch.
+        num_gpus (int): number of GPU in the data parallel mode.
+        shuffle (bool): if True, samples randomly, otherwise samples sequentially.
+        drop_last (bool): if True, drops last incomplete batch.
+    """
+
+    def __init__(
+        self,
+        dataset_items,
+        classes,
+        batch_size,
+        num_classes_in_batch,
+        num_gpus=1,
+        shuffle=True,
+        drop_last=False,
+        label_key="class_name",
+    ):
+        super().__init__(dataset_items)
+        assert (
+            batch_size % (num_classes_in_batch * num_gpus) == 0
+        ), "Batch size must be divisible by number of classes times the number of data parallel devices (if enabled)."
+
+        label_indices = {}
+        for idx, item in enumerate(dataset_items):
+            label = item[label_key]
+            if label not in label_indices.keys():
+                label_indices[label] = [idx]
+            else:
+                label_indices[label].append(idx)
+
+        if shuffle:
+            self._samplers = [SubsetRandomSampler(label_indices[key]) for key in classes]
+        else:
+            self._samplers = [SubsetSampler(label_indices[key]) for key in classes]
+
+        self._batch_size = batch_size
+        self._drop_last = drop_last
+        self._dp_devices = num_gpus
+        self._num_classes_in_batch = num_classes_in_batch
+
+    def __iter__(self):
+
+        batch = []
+        if self._num_classes_in_batch != len(self._samplers):
+            valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
+        else:
+            valid_samplers_idx = None
+
+        iters = [iter(s) for s in self._samplers]
+        done = False
+
+        while True:
+            b = []
+            for i, it in enumerate(iters):
+                if valid_samplers_idx is not None and i not in valid_samplers_idx:
+                    continue
+                idx = next(it, None)
+                if idx is None:
+                    done = True
+                    break
+                b.append(idx)
+            if done:
+                break
+            batch += b
+            if len(batch) == self._batch_size:
+                yield batch
+                batch = []
+                if valid_samplers_idx is not None:
+                    valid_samplers_idx = random.sample(range(len(self._samplers)), self._num_classes_in_batch)
+
+        if not self._drop_last:
+            if len(batch) > 0:
+                groups = len(batch) // self._num_classes_in_batch
+                if groups % self._dp_devices == 0:
+                    yield batch
+                else:
+                    batch = batch[: (groups // self._dp_devices) * self._dp_devices * self._num_classes_in_batch]
+                    if len(batch) > 0:
+                        yield batch
+
+    def __len__(self):
+        class_batch_size = self._batch_size // self._num_classes_in_batch
+        return min(((len(s) + class_batch_size - 1) // class_batch_size) for s in self._samplers)
+
+
+def identity(x):
+    return x
+
+
+class SortedSampler(Sampler):
+    """Samples elements sequentially, always in the same order.
+
+    Taken from https://github.com/PetrochukM/PyTorch-NLP
+
+    Args:
+        data (iterable): Iterable data.
+        sort_key (callable): Specifies a function of one argument that is used to extract a
+            numerical comparison key from each list element.
+
+    Example:
+        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
+        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
+
+    """
+
+    def __init__(self, data, sort_key: Callable = identity):
+        super().__init__(data)
+        self.data = data
+        self.sort_key = sort_key
+        zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)]
+        zip_ = sorted(zip_, key=lambda r: r[1])
+        self.sorted_indexes = [item[0] for item in zip_]
+
+    def __iter__(self):
+        return iter(self.sorted_indexes)
+
+    def __len__(self):
+        return len(self.data)
+
+
+class BucketBatchSampler(BatchSampler):
+    """Bucket batch sampler
+
+    Adapted from https://github.com/PetrochukM/PyTorch-NLP
+
+    Args:
+        sampler (torch.data.utils.sampler.Sampler):
+        batch_size (int): Size of mini-batch.
+        drop_last (bool): If `True` the sampler will drop the last batch if its size would be less
+            than `batch_size`.
+        data (list): List of data samples.
+        sort_key (callable, optional): Callable to specify a comparison key for sorting.
+        bucket_size_multiplier (int, optional): Buckets are of size
+            `batch_size * bucket_size_multiplier`.
+
+    Example:
+        >>> sampler = WeightedRandomSampler(weights, len(weights))
+        >>> sampler = BucketBatchSampler(sampler, data=data_items, batch_size=32, drop_last=True)
+    """
+
+    def __init__(
+        self,
+        sampler,
+        data,
+        batch_size,
+        drop_last,
+        sort_key: Union[Callable, List] = identity,
+        bucket_size_multiplier=100,
+    ):
+        super().__init__(sampler, batch_size, drop_last)
+        self.data = data
+        self.sort_key = sort_key
+        _bucket_size = batch_size * bucket_size_multiplier
+        if hasattr(sampler, "__len__"):
+            _bucket_size = min(_bucket_size, len(sampler))
+        self.bucket_sampler = BatchSampler(sampler, _bucket_size, False)
+
+    def __iter__(self):
+        for idxs in self.bucket_sampler:
+            bucket_data = [self.data[idx] for idx in idxs]
+            sorted_sampler = SortedSampler(bucket_data, self.sort_key)
+            for batch_idx in SubsetRandomSampler(list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))):
+                sorted_idxs = [idxs[i] for i in batch_idx]
+                yield sorted_idxs
+
+    def __len__(self):
+        if self.drop_last:
+            return len(self.sampler) // self.batch_size
+        return math.ceil(len(self.sampler) / self.batch_size)
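Tying the new module together, a minimal sketch of weighted, bucketed batching over toy metadata; everything except the `TTS.utils.samplers` import and its documented arguments is made up.

# End-to-end sketch of the new sampler utilities; the toy items and weights are made up.
import random
from torch.utils.data.sampler import WeightedRandomSampler
from TTS.utils.samplers import BucketBatchSampler

items = [{"text": "x" * random.randint(1, 50)} for _ in range(100)]  # toy dataset items
weights = [1.0] * len(items)  # uniform here; VITS derives these from get_attribute_balancer_weights

sampler = WeightedRandomSampler(weights, len(weights))
batch_sampler = BucketBatchSampler(
    sampler,
    data=items,
    batch_size=8,
    drop_last=True,
    sort_key=lambda x: len(x["text"]),  # batches are cut from length-sorted buckets
    bucket_size_multiplier=4,
)

for batch in batch_sampler:
    # each `batch` is a list of 8 indices into `items`, grouped by similar text length
    assert len(batch) == 8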
@@ -5,11 +5,11 @@ import unittest
 import torch

 from TTS.config.shared_configs import BaseDatasetConfig
-from TTS.encoder.utils.samplers import PerfectBatchSampler
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.utils.data import get_length_balancer_weights
 from TTS.tts.utils.languages import get_language_balancer_weights
 from TTS.tts.utils.speakers import get_speaker_balancer_weights
+from TTS.utils.samplers import BucketBatchSampler, PerfectBatchSampler

 # Fixing random state to avoid random fails
 torch.manual_seed(0)
@@ -163,3 +163,31 @@ class TestSamplers(unittest.TestCase):
             else:
                 len2 += 1
         assert is_balanced(len1, len2), "Length Weighted sampler is supposed to be balanced"
+
+    def test_bucket_batch_sampler(self):
+        bucket_size_multiplier = 2
+        sampler = range(len(train_samples))
+        sampler = BucketBatchSampler(
+            sampler,
+            data=train_samples,
+            batch_size=7,
+            drop_last=True,
+            sort_key=lambda x: len(x["text"]),
+            bucket_size_multiplier=bucket_size_multiplier,
+        )
+
+        # check if the samples are sorted by text length while bucketing
+        min_text_len_in_bucket = 0
+        bucket_items = []
+        for batch_idx, batch in enumerate(list(sampler)):
+            if (batch_idx + 1) % bucket_size_multiplier == 0:
+                for bucket_item in bucket_items:
+                    self.assertLessEqual(min_text_len_in_bucket, len(train_samples[bucket_item]["text"]))
+                    min_text_len_in_bucket = len(train_samples[bucket_item]["text"])
+                min_text_len_in_bucket = 0
+                bucket_items = []
+            else:
+                bucket_items += batch
+
+        # check sampler length
+        self.assertEqual(len(sampler), len(train_samples) // 7)