Transform the Speaker Encoder dataset to a generic dataset and create emotion encoder config

Author: Edresson Casanova
Date:   2022-03-01 09:09:37 -03:00
Commit: 854c887764 (parent 1c6d16cffc)
24 changed files with 130 additions and 110 deletions
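The changes below are mostly a mechanical rename of the speaker-specific encoder API to generic, class-based names, so the same dataset and config code can serve both the speaker encoder and the new emotion encoder. The renames that recur throughout the hunks:

    SpeakerEncoderDataset                      -> EncoderDataset
    num_speakers_in_batch                      -> num_classes_in_batch
    num_utters_per_speaker / num_utter_per_speaker -> num_utter_per_class
    skip_speakers                              -> skip_classes
    speaker_to_utters                          -> class_to_utters
    speakerid_to_classid                       -> classname_to_classid
    get_num_speakers()                         -> get_num_classes()
    "TTS.speaker_encoder" config search path   -> "TTS.encoder"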

View File

@@ -10,7 +10,7 @@ import torch
 from torch.utils.data import DataLoader
 from trainer.torch import NoamLR
-from TTS.encoder.dataset import SpeakerEncoderDataset
+from TTS.encoder.dataset import EncoderDataset
 from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
 from TTS.encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
 from TTS.encoder.utils.training import init_training
@@ -35,13 +35,13 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
     if is_val:
         loader = None
     else:
-        dataset = SpeakerEncoderDataset(
+        dataset = EncoderDataset(
             ap,
             meta_data_eval if is_val else meta_data_train,
             voice_len=c.voice_len,
-            num_utter_per_speaker=c.num_utters_per_speaker,
+            num_utter_per_class=c.num_utter_per_class,
-            num_speakers_in_batch=c.num_speakers_in_batch,
+            num_classes_in_batch=c.num_classes_in_batch,
-            skip_speakers=c.skip_speakers,
+            skip_classes=c.skip_classes,
             storage_size=c.storage["storage_size"],
             sample_from_storage_p=c.storage["sample_from_storage_p"],
             verbose=verbose,
@@ -52,12 +52,12 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
         # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         loader = DataLoader(
             dataset,
-            batch_size=c.num_speakers_in_batch,
+            batch_size=c.num_classes_in_batch,
             shuffle=False,
             num_workers=c.num_loader_workers,
             collate_fn=dataset.collate_fn,
         )
-    return loader, dataset.get_num_speakers()
+    return loader, dataset.get_num_classes()

 def train(model, optimizer, scheduler, criterion, data_loader, global_step):
@@ -91,7 +91,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, global_step):
         outputs = model(inputs)

         # loss computation
-        loss = criterion(outputs.view(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1), labels)
+        loss = criterion(outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels)
         loss.backward()
         grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()
@@ -160,14 +160,14 @@ def main(args):  # pylint: disable=redefined-outer-name
     # pylint: disable=redefined-outer-name
     meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=False)

-    data_loader, num_speakers = setup_loader(ap, is_val=False, verbose=True)
+    data_loader, num_classes = setup_loader(ap, is_val=False, verbose=True)

     if c.loss == "ge2e":
         criterion = GE2ELoss(loss_method="softmax")
     elif c.loss == "angleproto":
         criterion = AngleProtoLoss()
     elif c.loss == "softmaxproto":
-        criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_speakers)
+        criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
     else:
         raise Exception("The %s not is a loss supported" % c.loss)
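The reshape in the loss call above groups the flat batch of embeddings by class before it reaches the GE2E / AngleProto-style criteria. A toy sketch with invented sizes (the names below stand in for c.num_classes_in_batch, c.num_utter_per_class and the model's projection size; they are not taken from the commit):

import torch

num_classes_in_batch = 4  # stand-in for c.num_classes_in_batch
num_utter_per_class = 3   # stand-in for c.num_utter_per_class
proj_dim = 8              # stand-in for c.model_params["proj_dim"]

# flat batch: utterances arrive ordered class by class, as collate_fn builds them
outputs = torch.randn(num_classes_in_batch * num_utter_per_class, proj_dim)

# same view as in the training loop: (classes, utterances per class, embedding dim)
grouped = outputs.view(num_classes_in_batch, outputs.shape[0] // num_classes_in_batch, -1)
print(grouped.shape)  # torch.Size([4, 3, 8])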

View File

@@ -37,7 +37,7 @@ def register_config(model_name: str) -> Coqpit:
     """
     config_class = None
     config_name = model_name + "_config"
-    paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"]
+    paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder"]
     for path in paths:
         try:
             config_class = find_module(path, config_name)

View File

@@ -37,9 +37,9 @@
     "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
     "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
     "steps_plot_stats": 10, // number of steps to plot embeddings.
-    "num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
+    "num_classes_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
-    "num_utters_per_speaker": 10, //
+    "num_utter_per_class": 10, //
-    "skip_speakers": false, // skip speakers with samples less than "num_utters_per_speaker"
+    "skip_classes": false, // skip speakers with samples less than "num_utter_per_class"
     "voice_len": 1.6, // number of seconds for each training instance
     "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.

View File

@@ -42,9 +42,9 @@
     "steps_plot_stats": 100, // number of steps to plot embeddings.

     // Speakers config
-    "num_speakers_in_batch": 200, // Batch size for training.
+    "num_classes_in_batch": 200, // Batch size for training.
-    "num_utters_per_speaker": 2, //
+    "num_utter_per_class": 2, //
-    "skip_speakers": true, // skip speakers with samples less than "num_utters_per_speaker"
+    "skip_classes": true, // skip speakers with samples less than "num_utter_per_class"
     "voice_len": 2, // number of seconds for each training instance
     "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.

View File

@@ -43,9 +43,9 @@
     "steps_plot_stats": 100, // number of steps to plot embeddings.

     // Speakers config
-    "num_speakers_in_batch": 200, // Batch size for training.
+    "num_classes_in_batch": 200, // Batch size for training.
-    "num_utters_per_speaker": 2, //
+    "num_utter_per_class": 2, //
-    "skip_speakers": true, // skip speakers with samples less than "num_utters_per_speaker"
+    "skip_classes": true, // skip speakers with samples less than "num_utter_per_class"
     "voice_len": 2, // number of seconds for each training instance
     "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.

View File

@@ -7,17 +7,17 @@ from torch.utils.data import Dataset
 from TTS.encoder.utils.generic_utils import AugmentWAV, Storage


-class SpeakerEncoderDataset(Dataset):
+class EncoderDataset(Dataset):
     def __init__(
         self,
         ap,
         meta_data,
         voice_len=1.6,
-        num_speakers_in_batch=64,
+        num_classes_in_batch=64,
         storage_size=1,
         sample_from_storage_p=0.5,
-        num_utter_per_speaker=10,
+        num_utter_per_class=10,
-        skip_speakers=False,
+        skip_classes=False,
         verbose=False,
         augmentation_config=None,
         use_torch_spec=None,
@@ -33,22 +33,23 @@ class SpeakerEncoderDataset(Dataset):
         self.items = meta_data
         self.sample_rate = ap.sample_rate
         self.seq_len = int(voice_len * self.sample_rate)
-        self.num_speakers_in_batch = num_speakers_in_batch
+        self.num_classes_in_batch = num_classes_in_batch
-        self.num_utter_per_speaker = num_utter_per_speaker
+        self.num_utter_per_class = num_utter_per_class
-        self.skip_speakers = skip_speakers
+        self.skip_classes = skip_classes
         self.ap = ap
         self.verbose = verbose
         self.use_torch_spec = use_torch_spec
         self.__parse_items()
-        storage_max_size = storage_size * num_speakers_in_batch
+        storage_max_size = storage_size * num_classes_in_batch
         self.storage = Storage(
-            maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch
+            maxsize=storage_max_size, storage_batchs=storage_size, num_classes_in_batch=num_classes_in_batch
         )
         self.sample_from_storage_p = float(sample_from_storage_p)
-        speakers_aux = list(self.speakers)
+        classes_aux = list(self.classes)
-        speakers_aux.sort()
+        classes_aux.sort()
-        self.speakerid_to_classid = {key: i for i, key in enumerate(speakers_aux)}
+        self.classname_to_classid = {key: i for i, key in enumerate(classes_aux)}

         # Augmentation
         self.augmentator = None
@@ -63,156 +64,158 @@ class SpeakerEncoderDataset(Dataset):
         if self.verbose:
             print("\n > DataLoader initialization")
-            print(f" | > Speakers per Batch: {num_speakers_in_batch}")
+            print(f" | > Classes per Batch: {num_classes_in_batch}")
-            print(f" | > Storage Size: {storage_max_size} instances, each with {num_utter_per_speaker} utters")
+            print(f" | > Storage Size: {storage_max_size} instances, each with {num_utter_per_class} utters")
             print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}")
             print(f" | > Number of instances : {len(self.items)}")
             print(f" | > Sequence length: {self.seq_len}")
-            print(f" | > Num speakers: {len(self.speakers)}")
+            print(f" | > Num Classes: {len(self.classes)}")
+            print(f" | > Classes: {list(self.classes)}")

     def load_wav(self, filename):
         audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
         return audio

     def __parse_items(self):
-        self.speaker_to_utters = {}
+        self.class_to_utters = {}
         for i in self.items:
             path_ = i["audio_file"]
             speaker_ = i["speaker_name"]
             if speaker_ in self.speaker_to_utters.keys():
                 self.speaker_to_utters[speaker_].append(path_)
             else:
-                self.speaker_to_utters[speaker_] = [
+                self.class_to_utters[class_name] = [
                     path_,
                 ]

-        if self.skip_speakers:
+        if self.skip_classes:
-            self.speaker_to_utters = {
+            self.class_to_utters = {
-                k: v for (k, v) in self.speaker_to_utters.items() if len(v) >= self.num_utter_per_speaker
+                k: v for (k, v) in self.class_to_utters.items() if len(v) >= self.num_utter_per_class
             }

-        self.speakers = [k for (k, v) in self.speaker_to_utters.items()]
+        self.classes = [k for (k, v) in self.class_to_utters.items()]

     def __len__(self):
         return int(1e10)

-    def get_num_speakers(self):
+    def get_num_classes(self):
-        return len(self.speakers)
+        return len(self.classes)

-    def __sample_speaker(self, ignore_speakers=None):
+    def __sample_class(self, ignore_classes=None):
-        speaker = random.sample(self.speakers, 1)[0]
+        class_name = random.sample(self.classes, 1)[0]
-        # if list of speakers_id is provide make sure that it's will be ignored
+        # if list of classes_id is provide make sure that it's will be ignored
-        if ignore_speakers and self.speakerid_to_classid[speaker] in ignore_speakers:
+        if ignore_classes and self.classname_to_classid[class_name] in ignore_classes:
             while True:
-                speaker = random.sample(self.speakers, 1)[0]
+                class_name = random.sample(self.classes, 1)[0]
-                if self.speakerid_to_classid[speaker] not in ignore_speakers:
+                if self.classname_to_classid[class_name] not in ignore_classes:
                     break

-        if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
+        if self.num_utter_per_class > len(self.class_to_utters[class_name]):
-            utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker)
+            utters = random.choices(self.class_to_utters[class_name], k=self.num_utter_per_class)
         else:
-            utters = random.sample(self.speaker_to_utters[speaker], self.num_utter_per_speaker)
+            utters = random.sample(self.class_to_utters[class_name], self.num_utter_per_class)
-        return speaker, utters
+        return class_name, utters

-    def __sample_speaker_utterances(self, speaker):
+    def __sample_class_utterances(self, class_name):
         """
-        Sample all M utterances for the given speaker.
+        Sample all M utterances for the given class_name.
         """
         wavs = []
         labels = []
-        for _ in range(self.num_utter_per_speaker):
+        for _ in range(self.num_utter_per_class):
             # TODO:dummy but works
             while True:
-                # remove speakers that have num_utter less than 2
+                # remove classes that have num_utter less than 2
-                if len(self.speaker_to_utters[speaker]) > 1:
+                if len(self.class_to_utters[class_name]) > 1:
-                    utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
+                    utter = random.sample(self.class_to_utters[class_name], 1)[0]
                 else:
-                    if speaker in self.speakers:
+                    if class_name in self.classes:
-                        self.speakers.remove(speaker)
+                        self.classes.remove(class_name)
-                    speaker, _ = self.__sample_speaker()
+                    class_name, _ = self.__sample_class()
                     continue

                 wav = self.load_wav(utter)
                 if wav.shape[0] - self.seq_len > 0:
                     break

-            if utter in self.speaker_to_utters[speaker]:
+            if utter in self.class_to_utters[class_name]:
-                self.speaker_to_utters[speaker].remove(utter)
+                self.class_to_utters[class_name].remove(utter)

             if self.augmentator is not None and self.data_augmentation_p:
                 if random.random() < self.data_augmentation_p:
                     wav = self.augmentator.apply_one(wav)

             wavs.append(wav)
-            labels.append(self.speakerid_to_classid[speaker])
+            labels.append(self.classname_to_classid[class_name])
         return wavs, labels

     def __getitem__(self, idx):
-        speaker, _ = self.__sample_speaker()
+        class_name, _ = self.__sample_class()
-        speaker_id = self.speakerid_to_classid[speaker]
+        class_id = self.classname_to_classid[class_name]
-        return speaker, speaker_id
+        return class_name, class_id

-    def __load_from_disk_and_storage(self, speaker):
+    def __load_from_disk_and_storage(self, class_name):
         # don't sample from storage, but from HDD
-        wavs_, labels_ = self.__sample_speaker_utterances(speaker)
+        wavs_, labels_ = self.__sample_class_utterances(class_name)
         # put the newly loaded item into storage
         self.storage.append((wavs_, labels_))
         return wavs_, labels_

     def collate_fn(self, batch):
-        # get the batch speaker_ids
+        # get the batch class_ids
         batch = np.array(batch)
-        speakers_id_in_batch = set(batch[:, 1].astype(np.int32))
+        classes_id_in_batch = set(batch[:, 1].astype(np.int32))

         labels = []
         feats = []
-        speakers = set()
+        classes = set()

-        for speaker, speaker_id in batch:
+        for class_name, class_id in batch:
-            speaker_id = int(speaker_id)
+            class_id = int(class_id)

-            # ensure that an speaker appears only once in the batch
+            # ensure that an class appears only once in the batch
-            if speaker_id in speakers:
+            if class_id in classes:
-                # remove current speaker
+                # remove current class
-                if speaker_id in speakers_id_in_batch:
+                if class_id in classes_id_in_batch:
-                    speakers_id_in_batch.remove(speaker_id)
+                    classes_id_in_batch.remove(class_id)
-                speaker, _ = self.__sample_speaker(ignore_speakers=speakers_id_in_batch)
+                class_name, _ = self.__sample_class(ignore_classes=classes_id_in_batch)
-                speaker_id = self.speakerid_to_classid[speaker]
+                class_id = self.classname_to_classid[class_name]
-                speakers_id_in_batch.add(speaker_id)
+                classes_id_in_batch.add(class_id)

             if random.random() < self.sample_from_storage_p and self.storage.full():
                 # sample from storage (if full)
                 wavs_, labels_ = self.storage.get_random_sample_fast()

-                # force choose the current speaker or other not in batch
+                # force choose the current class or other not in batch
                 # It's necessary for ideal training with AngleProto and GE2E losses
-                if labels_[0] in speakers_id_in_batch and labels_[0] != speaker_id:
+                if labels_[0] in classes_id_in_batch and labels_[0] != class_id:
                     attempts = 0
                     while True:
                         wavs_, labels_ = self.storage.get_random_sample_fast()
-                        if labels_[0] == speaker_id or labels_[0] not in speakers_id_in_batch:
+                        if labels_[0] == class_id or labels_[0] not in classes_id_in_batch:
                             break

                         attempts += 1
                         # Try 5 times after that load from disk
                         if attempts >= 5:
-                            wavs_, labels_ = self.__load_from_disk_and_storage(speaker)
+                            wavs_, labels_ = self.__load_from_disk_and_storage(class_name)
                             break
             else:
                 # don't sample from storage, but from HDD
-                wavs_, labels_ = self.__load_from_disk_and_storage(speaker)
+                wavs_, labels_ = self.__load_from_disk_and_storage(class_name)

-            # append speaker for control
+            # append class for control
-            speakers.add(labels_[0])
+            classes.add(labels_[0])

-            # remove current speaker and append other
+            # remove current class and append other
-            if speaker_id in speakers_id_in_batch:
+            if class_id in classes_id_in_batch:
-                speakers_id_in_batch.remove(speaker_id)
+                classes_id_in_batch.remove(class_id)
-            speakers_id_in_batch.add(labels_[0])
+            classes_id_in_batch.add(labels_[0])

             # get a random subset of each of the wavs and extract mel spectrograms.
             feats_ = []
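Taken together with the training-script hunk at the top of the commit, the renamed dataset is wired up roughly as below. A minimal sketch: `ap` (an AudioProcessor) and `meta_data` (a list of items with "audio_file" and "speaker_name" keys) are assumed to exist already, and the literal values are illustrative rather than taken from the commit.

from torch.utils.data import DataLoader

from TTS.encoder.dataset import EncoderDataset

dataset = EncoderDataset(
    ap,                        # AudioProcessor, assumed to be built elsewhere
    meta_data,                 # list of {"audio_file": ..., "speaker_name": ...} items
    voice_len=1.6,
    num_utter_per_class=10,    # was num_utter_per_speaker
    num_classes_in_batch=64,   # was num_speakers_in_batch
    skip_classes=False,        # was skip_speakers
    storage_size=1,
    sample_from_storage_p=0.5,
)

# one DataLoader item per class; collate_fn expands each item into
# num_utter_per_class utterances, so a batch holds 64 * 10 utterances here
loader = DataLoader(
    dataset,
    batch_size=dataset.num_classes_in_batch,
    shuffle=False,
    collate_fn=dataset.collate_fn,
)
num_classes = dataset.get_num_classes()  # was get_num_speakers()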

View File

@@ -0,0 +1,17 @@
+from dataclasses import asdict, dataclass
+
+from TTS.encoder.speaker_encoder_config import SpeakerEncoderConfig
+
+
+@dataclass
+class EmotionEncoderConfig(SpeakerEncoderConfig):
+    """Defines parameters for Speaker Encoder model."""
+
+    model: str = "emotion_encoder"
+
+    def check_values(self):
+        super().check_values()
+        c = asdict(self)
+        assert (
+            c["model_params"]["input_dim"] == self.audio.num_mels
+        ), " [!] model input dimension must be equal to melspectrogram dimension."

View File

@@ -51,10 +51,10 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
     print_step: int = 20

     # data loader
-    num_speakers_in_batch: int = MISSING
+    num_classes_in_batch: int = MISSING
-    num_utters_per_speaker: int = MISSING
+    num_utter_per_class: int = MISSING
     num_loader_workers: int = MISSING
-    skip_speakers: bool = False
+    skip_classes: bool = False
     voice_len: float = 1.6

     def check_values(self):

View File

@@ -14,11 +14,11 @@ from TTS.utils.io import save_fsspec

 class Storage(object):
-    def __init__(self, maxsize, storage_batchs, num_speakers_in_batch, num_threads=8):
+    def __init__(self, maxsize, storage_batchs, num_classes_in_batch, num_threads=8):
         # use multiprocessing for threading safe
         self.storage = Manager().list()
         self.maxsize = maxsize
-        self.num_speakers_in_batch = num_speakers_in_batch
+        self.num_classes_in_batch = num_classes_in_batch
         self.num_threads = num_threads
         self.ignore_last_batch = False
@@ -28,7 +28,7 @@ class Storage(object):
         # used for fast random sample
         self.safe_storage_size = self.maxsize - self.num_threads
         if self.ignore_last_batch:
-            self.safe_storage_size -= self.num_speakers_in_batch
+            self.safe_storage_size -= self.num_classes_in_batch

     def __len__(self):
         return len(self.storage)
@@ -48,7 +48,7 @@ class Storage(object):
         storage_size = len(self.storage) - self.num_threads
         if self.ignore_last_batch:
-            storage_size -= self.num_speakers_in_batch
+            storage_size -= self.num_classes_in_batch
         return self.storage[random.randint(0, storage_size)]
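For reference, this is how the dataset constructor shown earlier fills in the renamed Storage argument; the sizes mirror that call, and the literal values are illustrative only.

from TTS.encoder.utils.generic_utils import Storage

storage_size = 1           # c.storage["storage_size"] in the training config
num_classes_in_batch = 64  # c.num_classes_in_batch

storage = Storage(
    maxsize=storage_size * num_classes_in_batch,
    storage_batchs=storage_size,
    num_classes_in_batch=num_classes_in_batch,  # was num_speakers_in_batch
)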

View File

@@ -29,12 +29,12 @@ colormap = (
 )

-def plot_embeddings(embeddings, num_utter_per_speaker):
+def plot_embeddings(embeddings, num_utter_per_class):
-    embeddings = embeddings[: 10 * num_utter_per_speaker]
+    embeddings = embeddings[: 10 * num_utter_per_class]
     model = umap.UMAP()
     projection = model.fit_transform(embeddings)
-    num_speakers = embeddings.shape[0] // num_utter_per_speaker
+    num_speakers = embeddings.shape[0] // num_utter_per_class
-    ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_speaker)
+    ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_class)
     colors = [colormap[i] for i in ground_truth]
     fig, ax = plt.subplots(figsize=(16, 10))

Binary file not shown (before: 24 KiB).

View File

@@ -435,7 +435,7 @@ def emotion(root_path, meta_file, ignored_speakers=None):
             if isinstance(ignored_speakers, list):
                 if speaker_id in ignored_speakers:
                     continue
-            items.append([wav_file, speaker_id, emotion_id])
+            items.append([speaker_id, wav_file, emotion_id])
     return items

View File

@@ -24,8 +24,8 @@ output_path = os.path.join(get_tests_output_path(), "train_outputs")

 config = SpeakerEncoderConfig(
     batch_size=4,
-    num_speakers_in_batch=1,
+    num_classes_in_batch=1,
-    num_utters_per_speaker=10,
+    num_utter_per_class=10,
     num_loader_workers=0,
     max_train_step=2,
     print_step=1,

View File

@@ -36,8 +36,8 @@
     "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
     "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
     "steps_plot_stats": 10, // number of steps to plot embeddings.
-    "num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
+    "num_classes_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
-    "num_utters_per_speaker": 10, //
+    "num_utter_per_class": 10, //
     "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
     "wd": 0.000001, // Weight decay weight.
     "checkpoint": true, // If true, it saves checkpoints per "save_step"