mirror of https://github.com/coqui-ai/TTS.git
Transform the Speaker Encoder dataset to a generic dataset and create emotion encoder config
parent 1c6d16cffc · commit 854c887764
@@ -10,7 +10,7 @@ import torch
 from torch.utils.data import DataLoader
 from trainer.torch import NoamLR

-from TTS.encoder.dataset import SpeakerEncoderDataset
+from TTS.encoder.dataset import EncoderDataset
 from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
 from TTS.encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model
 from TTS.encoder.utils.training import init_training
@@ -35,13 +35,13 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
     if is_val:
         loader = None
     else:
-        dataset = SpeakerEncoderDataset(
+        dataset = EncoderDataset(
             ap,
             meta_data_eval if is_val else meta_data_train,
             voice_len=c.voice_len,
-            num_utter_per_speaker=c.num_utters_per_speaker,
-            num_speakers_in_batch=c.num_speakers_in_batch,
-            skip_speakers=c.skip_speakers,
+            num_utter_per_class=c.num_utter_per_class,
+            num_classes_in_batch=c.num_classes_in_batch,
+            skip_classes=c.skip_classes,
             storage_size=c.storage["storage_size"],
             sample_from_storage_p=c.storage["sample_from_storage_p"],
             verbose=verbose,
@@ -52,12 +52,12 @@ def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False
         # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
         loader = DataLoader(
             dataset,
-            batch_size=c.num_speakers_in_batch,
+            batch_size=c.num_classes_in_batch,
             shuffle=False,
             num_workers=c.num_loader_workers,
             collate_fn=dataset.collate_fn,
         )
-    return loader, dataset.get_num_speakers()
+    return loader, dataset.get_num_classes()


 def train(model, optimizer, scheduler, criterion, data_loader, global_step):
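Note on `batch_size` above: it now counts classes, not utterances. Each `__getitem__` of the dataset returns a single class, and `collate_fn` expands it into `num_utter_per_class` samples. A minimal, self-contained sketch of that contract (`ToyClassDataset` is a made-up stand-in, not the repo's class):

from torch.utils.data import DataLoader, Dataset

class ToyClassDataset(Dataset):
    """Stand-in for EncoderDataset: one item per class, expanded in collate_fn."""

    def __init__(self, num_classes=4, num_utter_per_class=3):
        self.num_classes = num_classes
        self.num_utter_per_class = num_utter_per_class

    def __len__(self):
        return self.num_classes

    def __getitem__(self, idx):
        return f"class_{idx}", idx  # (class_name, class_id)

    def collate_fn(self, batch):
        # expand each sampled class into its utterance labels
        return [cid for _, cid in batch for _ in range(self.num_utter_per_class)]

dataset = ToyClassDataset()
loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=dataset.collate_fn)
print(next(iter(loader)))  # [0, 0, 0, 1, 1, 1]: 2 classes x 3 utterances each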
@@ -91,7 +91,7 @@ def train(model, optimizer, scheduler, criterion, data_loader, global_step):
         outputs = model(inputs)

         # loss computation
-        loss = criterion(outputs.view(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1), labels)
+        loss = criterion(outputs.view(c.num_classes_in_batch, outputs.shape[0] // c.num_classes_in_batch, -1), labels)
         loss.backward()
         grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()
@@ -160,14 +160,14 @@ def main(args): # pylint: disable=redefined-outer-name
     # pylint: disable=redefined-outer-name
     meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=False)

-    data_loader, num_speakers = setup_loader(ap, is_val=False, verbose=True)
+    data_loader, num_classes = setup_loader(ap, is_val=False, verbose=True)

     if c.loss == "ge2e":
         criterion = GE2ELoss(loss_method="softmax")
     elif c.loss == "angleproto":
         criterion = AngleProtoLoss()
     elif c.loss == "softmaxproto":
-        criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_speakers)
+        criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_classes)
     else:
         raise Exception("The %s not is a loss supported" % c.loss)

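The `view()` call in the loss computation assumes the flat embedding batch is ordered class-major. A self-contained sketch with made-up sizes:

import torch

num_classes_in_batch, num_utter_per_class, proj_dim = 4, 3, 8  # made-up sizes
outputs = torch.randn(num_classes_in_batch * num_utter_per_class, proj_dim)

# same regrouping as the train() change: one row of utterances per class
grouped = outputs.view(num_classes_in_batch, outputs.shape[0] // num_classes_in_batch, -1)
print(grouped.shape)  # torch.Size([4, 3, 8])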
@@ -37,7 +37,7 @@ def register_config(model_name: str) -> Coqpit:
     """
     config_class = None
     config_name = model_name + "_config"
-    paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"]
+    paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder"]
     for path in paths:
         try:
             config_class = find_module(path, config_name)
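This path rename matters because `register_config` resolves `<model_name>_config` modules by probing each listed package. A rough, hedged approximation of that lookup, using `importlib` directly (`find_config` is illustrative, not the repo's `find_module`):

import importlib

def find_config(model_name: str):
    # illustrative approximation of the repo's find_module() convention
    config_name = model_name + "_config"
    class_name = "".join(w.title() for w in config_name.split("_"))  # EmotionEncoderConfig
    for path in ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.encoder"]:
        try:
            module = importlib.import_module(path + "." + config_name)
            return getattr(module, class_name)
        except (ImportError, AttributeError):
            continue
    raise ImportError(f"no config module found for '{model_name}'")

Under this convention, `find_config("emotion_encoder")` resolves once the new config is importable as `TTS.encoder.emotion_encoder_config`.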
@@ -37,9 +37,9 @@
     "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
     "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
     "steps_plot_stats": 10, // number of steps to plot embeddings.
-    "num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
-    "num_utters_per_speaker": 10, //
-    "skip_speakers": false, // skip speakers with samples less than "num_utters_per_speaker"
+    "num_classes_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
+    "num_utter_per_class": 10, //
+    "skip_classes": false, // skip speakers with samples less than "num_utter_per_class"

     "voice_len": 1.6, // number of seconds for each training instance
     "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
@@ -42,9 +42,9 @@
     "steps_plot_stats": 100, // number of steps to plot embeddings.

     // Speakers config
-    "num_speakers_in_batch": 200, // Batch size for training.
-    "num_utters_per_speaker": 2, //
-    "skip_speakers": true, // skip speakers with samples less than "num_utters_per_speaker"
+    "num_classes_in_batch": 200, // Batch size for training.
+    "num_utter_per_class": 2, //
+    "skip_classes": true, // skip speakers with samples less than "num_utter_per_class"
     "voice_len": 2, // number of seconds for each training instance

     "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
@@ -43,9 +43,9 @@
     "steps_plot_stats": 100, // number of steps to plot embeddings.

     // Speakers config
-    "num_speakers_in_batch": 200, // Batch size for training.
-    "num_utters_per_speaker": 2, //
-    "skip_speakers": true, // skip speakers with samples less than "num_utters_per_speaker"
+    "num_classes_in_batch": 200, // Batch size for training.
+    "num_utter_per_class": 2, //
+    "skip_classes": true, // skip speakers with samples less than "num_utter_per_class"
     "voice_len": 2, // number of seconds for each training instance

     "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
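The three JSON edits above are the same mechanical rename. A hypothetical helper (not part of the commit) that upgrades an old config dict to the new keys:

# hypothetical migration helper; key names taken from the diff above
KEY_MAP = {
    "num_speakers_in_batch": "num_classes_in_batch",
    "num_utters_per_speaker": "num_utter_per_class",
    "skip_speakers": "skip_classes",
}

def migrate_config(cfg: dict) -> dict:
    return {KEY_MAP.get(key, key): value for key, value in cfg.items()}

old = {"num_speakers_in_batch": 64, "num_utters_per_speaker": 10, "voice_len": 1.6}
print(migrate_config(old))
# {'num_classes_in_batch': 64, 'num_utter_per_class': 10, 'voice_len': 1.6}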
@@ -7,17 +7,17 @@ from torch.utils.data import Dataset
 from TTS.encoder.utils.generic_utils import AugmentWAV, Storage


-class SpeakerEncoderDataset(Dataset):
+class EncoderDataset(Dataset):
     def __init__(
         self,
         ap,
         meta_data,
         voice_len=1.6,
-        num_speakers_in_batch=64,
+        num_classes_in_batch=64,
         storage_size=1,
         sample_from_storage_p=0.5,
-        num_utter_per_speaker=10,
-        skip_speakers=False,
+        num_utter_per_class=10,
+        skip_classes=False,
         verbose=False,
         augmentation_config=None,
         use_torch_spec=None,
@@ -33,22 +33,23 @@ class SpeakerEncoderDataset(Dataset):
         self.items = meta_data
         self.sample_rate = ap.sample_rate
         self.seq_len = int(voice_len * self.sample_rate)
-        self.num_speakers_in_batch = num_speakers_in_batch
-        self.num_utter_per_speaker = num_utter_per_speaker
-        self.skip_speakers = skip_speakers
+        self.num_classes_in_batch = num_classes_in_batch
+        self.num_utter_per_class = num_utter_per_class
+        self.skip_classes = skip_classes
         self.ap = ap
         self.verbose = verbose
         self.use_torch_spec = use_torch_spec
         self.__parse_items()
-        storage_max_size = storage_size * num_speakers_in_batch
+
+        storage_max_size = storage_size * num_classes_in_batch
         self.storage = Storage(
-            maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch
+            maxsize=storage_max_size, storage_batchs=storage_size, num_classes_in_batch=num_classes_in_batch
         )
         self.sample_from_storage_p = float(sample_from_storage_p)

-        speakers_aux = list(self.speakers)
-        speakers_aux.sort()
-        self.speakerid_to_classid = {key: i for i, key in enumerate(speakers_aux)}
+        classes_aux = list(self.classes)
+        classes_aux.sort()
+        self.classname_to_classid = {key: i for i, key in enumerate(classes_aux)}

         # Augmentation
         self.augmentator = None
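A hedged construction sketch for the renamed class. The stub `AudioProcessor` and fake metadata are illustrative assumptions, and running it presumes the post-commit `TTS` package is importable:

# all values below are illustrative; requires the post-commit TTS package
from types import SimpleNamespace

from TTS.encoder.dataset import EncoderDataset

ap = SimpleNamespace(sample_rate=16000, load_wav=lambda f, sr: None)  # stub AudioProcessor
meta_data = [
    {"audio_file": f"clip_{i:03d}.wav", "speaker_name": f"class_{i % 2}"} for i in range(20)
]

dataset = EncoderDataset(
    ap,
    meta_data,
    voice_len=1.6,
    num_utter_per_class=10,
    num_classes_in_batch=2,
    skip_classes=True,
    storage_size=1,
    sample_from_storage_p=0.5,
    verbose=True,
)
print(dataset.get_num_classes())  # 2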
@@ -63,156 +64,158 @@ class SpeakerEncoderDataset(Dataset):

         if self.verbose:
             print("\n > DataLoader initialization")
-            print(f" | > Speakers per Batch: {num_speakers_in_batch}")
-            print(f" | > Storage Size: {storage_max_size} instances, each with {num_utter_per_speaker} utters")
+            print(f" | > Classes per Batch: {num_classes_in_batch}")
+            print(f" | > Storage Size: {storage_max_size} instances, each with {num_utter_per_class} utters")
             print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}")
             print(f" | > Number of instances : {len(self.items)}")
             print(f" | > Sequence length: {self.seq_len}")
-            print(f" | > Num speakers: {len(self.speakers)}")
+            print(f" | > Num Classes: {len(self.classes)}")
+            print(f" | > Classes: {list(self.classes)}")

     def load_wav(self, filename):
         audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
         return audio

     def __parse_items(self):
-        self.speaker_to_utters = {}
+        self.class_to_utters = {}
         for i in self.items:
             path_ = i["audio_file"]
-            speaker_ = i["speaker_name"]
-            if speaker_ in self.speaker_to_utters.keys():
-                self.speaker_to_utters[speaker_].append(path_)
+            class_name = i["speaker_name"]
+            if class_name in self.class_to_utters.keys():
+                self.class_to_utters[class_name].append(path_)
             else:
-                self.speaker_to_utters[speaker_] = [
+                self.class_to_utters[class_name] = [
                     path_,
                 ]

-        if self.skip_speakers:
-            self.speaker_to_utters = {
-                k: v for (k, v) in self.speaker_to_utters.items() if len(v) >= self.num_utter_per_speaker
+        if self.skip_classes:
+            self.class_to_utters = {
+                k: v for (k, v) in self.class_to_utters.items() if len(v) >= self.num_utter_per_class
             }

-        self.speakers = [k for (k, v) in self.speaker_to_utters.items()]
+        self.classes = [k for (k, v) in self.class_to_utters.items()]

     def __len__(self):
         return int(1e10)

-    def get_num_speakers(self):
-        return len(self.speakers)
+    def get_num_classes(self):
+        return len(self.classes)

-    def __sample_speaker(self, ignore_speakers=None):
-        speaker = random.sample(self.speakers, 1)[0]
-        # if list of speakers_id is provide make sure that it's will be ignored
-        if ignore_speakers and self.speakerid_to_classid[speaker] in ignore_speakers:
+    def __sample_class(self, ignore_classes=None):
+        class_name = random.sample(self.classes, 1)[0]
+        # if list of classes_id is provide make sure that it's will be ignored
+        if ignore_classes and self.classname_to_classid[class_name] in ignore_classes:
             while True:
-                speaker = random.sample(self.speakers, 1)[0]
-                if self.speakerid_to_classid[speaker] not in ignore_speakers:
+                class_name = random.sample(self.classes, 1)[0]
+                if self.classname_to_classid[class_name] not in ignore_classes:
                     break

-        if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
-            utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker)
+        if self.num_utter_per_class > len(self.class_to_utters[class_name]):
+            utters = random.choices(self.class_to_utters[class_name], k=self.num_utter_per_class)
         else:
-            utters = random.sample(self.speaker_to_utters[speaker], self.num_utter_per_speaker)
-        return speaker, utters
+            utters = random.sample(self.class_to_utters[class_name], self.num_utter_per_class)
+        return class_name, utters

-    def __sample_speaker_utterances(self, speaker):
+    def __sample_class_utterances(self, class_name):
         """
-        Sample all M utterances for the given speaker.
+        Sample all M utterances for the given class_name.
         """
         wavs = []
         labels = []
-        for _ in range(self.num_utter_per_speaker):
+        for _ in range(self.num_utter_per_class):
             # TODO:dummy but works
             while True:
-                # remove speakers that have num_utter less than 2
-                if len(self.speaker_to_utters[speaker]) > 1:
-                    utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
+                # remove classes that have num_utter less than 2
+                if len(self.class_to_utters[class_name]) > 1:
+                    utter = random.sample(self.class_to_utters[class_name], 1)[0]
                 else:
-                    if speaker in self.speakers:
-                        self.speakers.remove(speaker)
+                    if class_name in self.classes:
+                        self.classes.remove(class_name)

-                    speaker, _ = self.__sample_speaker()
+                    class_name, _ = self.__sample_class()
                     continue

                 wav = self.load_wav(utter)
                 if wav.shape[0] - self.seq_len > 0:
                     break

-                if utter in self.speaker_to_utters[speaker]:
-                    self.speaker_to_utters[speaker].remove(utter)
+                if utter in self.class_to_utters[class_name]:
+                    self.class_to_utters[class_name].remove(utter)

             if self.augmentator is not None and self.data_augmentation_p:
                 if random.random() < self.data_augmentation_p:
                     wav = self.augmentator.apply_one(wav)

             wavs.append(wav)
-            labels.append(self.speakerid_to_classid[speaker])
+            labels.append(self.classname_to_classid[class_name])
         return wavs, labels

     def __getitem__(self, idx):
-        speaker, _ = self.__sample_speaker()
-        speaker_id = self.speakerid_to_classid[speaker]
-        return speaker, speaker_id
+        class_name, _ = self.__sample_class()
+        class_id = self.classname_to_classid[class_name]
+        return class_name, class_id

-    def __load_from_disk_and_storage(self, speaker):
+    def __load_from_disk_and_storage(self, class_name):
         # don't sample from storage, but from HDD
-        wavs_, labels_ = self.__sample_speaker_utterances(speaker)
+        wavs_, labels_ = self.__sample_class_utterances(class_name)
         # put the newly loaded item into storage
         self.storage.append((wavs_, labels_))
         return wavs_, labels_

     def collate_fn(self, batch):
-        # get the batch speaker_ids
+        # get the batch class_ids
         batch = np.array(batch)
-        speakers_id_in_batch = set(batch[:, 1].astype(np.int32))
+        classes_id_in_batch = set(batch[:, 1].astype(np.int32))

         labels = []
         feats = []
-        speakers = set()
+        classes = set()

-        for speaker, speaker_id in batch:
-            speaker_id = int(speaker_id)
+        for class_name, class_id in batch:
+            class_id = int(class_id)

-            # ensure that an speaker appears only once in the batch
-            if speaker_id in speakers:
+            # ensure that an class appears only once in the batch
+            if class_id in classes:

-                # remove current speaker
-                if speaker_id in speakers_id_in_batch:
-                    speakers_id_in_batch.remove(speaker_id)
+                # remove current class
+                if class_id in classes_id_in_batch:
+                    classes_id_in_batch.remove(class_id)

-                speaker, _ = self.__sample_speaker(ignore_speakers=speakers_id_in_batch)
-                speaker_id = self.speakerid_to_classid[speaker]
-                speakers_id_in_batch.add(speaker_id)
+                class_name, _ = self.__sample_class(ignore_classes=classes_id_in_batch)
+                class_id = self.classname_to_classid[class_name]
+                classes_id_in_batch.add(class_id)

             if random.random() < self.sample_from_storage_p and self.storage.full():
                 # sample from storage (if full)
                 wavs_, labels_ = self.storage.get_random_sample_fast()

-                # force choose the current speaker or other not in batch
+                # force choose the current class or other not in batch
                 # It's necessary for ideal training with AngleProto and GE2E losses
-                if labels_[0] in speakers_id_in_batch and labels_[0] != speaker_id:
+                if labels_[0] in classes_id_in_batch and labels_[0] != class_id:
                     attempts = 0
                     while True:
                         wavs_, labels_ = self.storage.get_random_sample_fast()
-                        if labels_[0] == speaker_id or labels_[0] not in speakers_id_in_batch:
+                        if labels_[0] == class_id or labels_[0] not in classes_id_in_batch:
                             break

                         attempts += 1
                         # Try 5 times after that load from disk
                         if attempts >= 5:
-                            wavs_, labels_ = self.__load_from_disk_and_storage(speaker)
+                            wavs_, labels_ = self.__load_from_disk_and_storage(class_name)
                             break
             else:
                 # don't sample from storage, but from HDD
-                wavs_, labels_ = self.__load_from_disk_and_storage(speaker)
+                wavs_, labels_ = self.__load_from_disk_and_storage(class_name)

-            # append speaker for control
-            speakers.add(labels_[0])
+            # append class for control
+            classes.add(labels_[0])

-            # remove current speaker and append other
-            if speaker_id in speakers_id_in_batch:
-                speakers_id_in_batch.remove(speaker_id)
+            # remove current class and append other
+            if class_id in classes_id_in_batch:
+                classes_id_in_batch.remove(class_id)

-            speakers_id_in_batch.add(labels_[0])
+            classes_id_in_batch.add(labels_[0])

             # get a random subset of each of the wavs and extract mel spectrograms.
             feats_ = []
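The subtlest piece of the renamed `collate_fn` is the storage retry rule: a cached sample is accepted only if its class is the current one or is not already in the batch, which AngleProto- and GE2E-style losses rely on. A standalone sketch of the same accept/retry logic, with a made-up cache and simplified API:

import random

def draw_non_colliding(cache, class_id, ids_in_batch, reload_fn, max_attempts=5):
    """Keep redrawing from the cache until the sample's class is either the
    current one or not yet in the batch; after max_attempts, load from disk."""
    wavs, labels = random.choice(cache)
    attempts = 0
    while True:
        if labels[0] == class_id or labels[0] not in ids_in_batch:
            return wavs, labels
        attempts += 1
        if attempts >= max_attempts:
            return reload_fn(class_id)
        wavs, labels = random.choice(cache)

cache = [(["wav_a"], [0]), (["wav_b"], [1]), (["wav_c"], [2])]
print(draw_non_colliding(cache, class_id=2, ids_in_batch={0, 1},
                         reload_fn=lambda cid: ([f"fresh_{cid}"], [cid])))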
@@ -0,0 +1,17 @@
+from dataclasses import asdict, dataclass
+
+from TTS.encoder.speaker_encoder_config import SpeakerEncoderConfig
+
+
+@dataclass
+class EmotionEncoderConfig(SpeakerEncoderConfig):
+    """Defines parameters for Speaker Encoder model."""
+
+    model: str = "emotion_encoder"
+
+    def check_values(self):
+        super().check_values()
+        c = asdict(self)
+        assert (
+            c["model_params"]["input_dim"] == self.audio.num_mels
+        ), " [!] model input dimendion must be equal to melspectrogram dimension."
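A hedged usage sketch for the new config; it assumes the file above is importable as `TTS.encoder.emotion_encoder_config` (the mirror does not show the file path) and uses illustrative values:

from TTS.encoder.emotion_encoder_config import EmotionEncoderConfig  # path assumed

config = EmotionEncoderConfig(
    num_classes_in_batch=32,  # emotions per batch, illustrative
    num_utter_per_class=4,
    num_loader_workers=4,
)
assert config.model == "emotion_encoder"
config.check_values()  # asserts model input_dim == audio.num_mels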
@@ -51,10 +51,10 @@ class SpeakerEncoderConfig(BaseTrainingConfig):
     print_step: int = 20

     # data loader
-    num_speakers_in_batch: int = MISSING
-    num_utters_per_speaker: int = MISSING
+    num_classes_in_batch: int = MISSING
+    num_utter_per_class: int = MISSING
     num_loader_workers: int = MISSING
-    skip_speakers: bool = False
+    skip_classes: bool = False
     voice_len: float = 1.6

     def check_values(self):
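The renamed loader fields keep coqpit's `MISSING` sentinel, so a user config must set them explicitly. A minimal sketch with illustrative values:

from TTS.encoder.speaker_encoder_config import SpeakerEncoderConfig

config = SpeakerEncoderConfig(
    num_classes_in_batch=64,  # MISSING by default: must be provided
    num_utter_per_class=10,   # MISSING by default: must be provided
    num_loader_workers=4,     # MISSING by default: must be provided
)
print(config.skip_classes, config.voice_len)  # defaults: False 1.6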
@@ -14,11 +14,11 @@ from TTS.utils.io import save_fsspec


 class Storage(object):
-    def __init__(self, maxsize, storage_batchs, num_speakers_in_batch, num_threads=8):
+    def __init__(self, maxsize, storage_batchs, num_classes_in_batch, num_threads=8):
         # use multiprocessing for threading safe
         self.storage = Manager().list()
         self.maxsize = maxsize
-        self.num_speakers_in_batch = num_speakers_in_batch
+        self.num_classes_in_batch = num_classes_in_batch
         self.num_threads = num_threads
         self.ignore_last_batch = False

@@ -28,7 +28,7 @@ class Storage(object):
         # used for fast random sample
         self.safe_storage_size = self.maxsize - self.num_threads
         if self.ignore_last_batch:
-            self.safe_storage_size -= self.num_speakers_in_batch
+            self.safe_storage_size -= self.num_classes_in_batch

     def __len__(self):
         return len(self.storage)
@@ -48,7 +48,7 @@ class Storage(object):
         storage_size = len(self.storage) - self.num_threads

         if self.ignore_last_batch:
-            storage_size -= self.num_speakers_in_batch
+            storage_size -= self.num_classes_in_batch

         return self.storage[random.randint(0, storage_size)]
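A worked example of the bound computed above, with made-up sizes: slots reserved for worker threads, and optionally the last batch, are excluded from fast random sampling:

# illustrative numbers, not repo defaults
maxsize, num_threads, num_classes_in_batch = 256, 8, 32

safe_storage_size = maxsize - num_threads   # keep per-thread slots out of reach
ignore_last_batch = True
if ignore_last_batch:
    safe_storage_size -= num_classes_in_batch  # also skip the freshest batch

print(safe_storage_size)  # 216 entries remain eligible for fast random sampling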
@@ -29,12 +29,12 @@ colormap = (
 )


-def plot_embeddings(embeddings, num_utter_per_speaker):
-    embeddings = embeddings[: 10 * num_utter_per_speaker]
+def plot_embeddings(embeddings, num_utter_per_class):
+    embeddings = embeddings[: 10 * num_utter_per_class]
     model = umap.UMAP()
     projection = model.fit_transform(embeddings)
-    num_speakers = embeddings.shape[0] // num_utter_per_speaker
-    ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_speaker)
+    num_speakers = embeddings.shape[0] // num_utter_per_class
+    ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_class)
     colors = [colormap[i] for i in ground_truth]

     fig, ax = plt.subplots(figsize=(16, 10))
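A hedged call sketch with synthetic embeddings shaped as 10 classes times `num_utter_per_class`, which is what the truncation on the function's first line keeps. The import path assumes the post-rename `TTS.encoder.utils.visual` module:

import numpy as np

from TTS.encoder.utils.visual import plot_embeddings  # module path assumed post-rename

num_utter_per_class = 5
embeddings = np.random.rand(10 * num_utter_per_class, 64).astype("float32")
figure = plot_embeddings(embeddings, num_utter_per_class)  # UMAP scatter, one color per class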
Binary file not shown (image diff; before: 24 KiB).
@@ -435,7 +435,7 @@ def emotion(root_path, meta_file, ignored_speakers=None):
         if isinstance(ignored_speakers, list):
             if speaker_id in ignored_speakers:
                 continue
-        items.append([wav_file, speaker_id, emotion_id])
+        items.append([speaker_id, wav_file, emotion_id])
     return items

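The reorder above only swaps two positions, but downstream consumers unpack formatter items positionally, so it is load-bearing. A tiny illustration with made-up values:

# new order assumed from the diff: [speaker_id, wav_file, emotion_id]
item = ["speaker_1", "clips/0001.wav", "happy"]
speaker_id, wav_file, emotion_id = item
print(wav_file)  # clips/0001.wav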
|
@ -24,8 +24,8 @@ output_path = os.path.join(get_tests_output_path(), "train_outputs")
|
|||
|
||||
config = SpeakerEncoderConfig(
|
||||
batch_size=4,
|
||||
num_speakers_in_batch=1,
|
||||
num_utters_per_speaker=10,
|
||||
num_classes_in_batch=1,
|
||||
num_utter_per_class=10,
|
||||
num_loader_workers=0,
|
||||
max_train_step=2,
|
||||
print_step=1,
|
||||
|
|
|
@ -36,8 +36,8 @@
|
|||
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
|
||||
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
||||
"steps_plot_stats": 10, // number of steps to plot embeddings.
|
||||
"num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
||||
"num_utters_per_speaker": 10, //
|
||||
"num_classes_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
|
||||
"num_utter_per_class": 10, //
|
||||
"num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values.
|
||||
"wd": 0.000001, // Weight decay weight.
|
||||
"checkpoint": true, // If true, it saves checkpoints per "save_step"
|
||||
|
|