Fix seed in test_samplers to avoid random fails

This commit is contained in:
WeberJulian 2021-12-01 23:48:38 +01:00 committed by Eren Gölge
parent 6f01eed672
commit 8b3769c957
1 changed files with 12 additions and 10 deletions

View File

@ -1,10 +1,12 @@
from torch.utils.data import RandomSampler
from TTS.tts.datasets import load_tts_samples
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.utils.languages import get_language_weighted_sampler
import torch
import functools
# Fixing random state to avoid random fails
torch.manual_seed(0)
dataset_config_en = BaseDatasetConfig(
name="ljspeech",
meta_file_train="metadata.csv",
@ -23,18 +25,18 @@ dataset_config_pt = BaseDatasetConfig(
# Adding the EN samples twice to create an unbalanced dataset
train_samples, eval_samples = load_tts_samples(
[dataset_config_en, dataset_config_en, dataset_config_pt],
[dataset_config_en, dataset_config_en, dataset_config_pt],
eval_split=True
)
def is_balanced(lang_1, lang_2):
return 0.9 < lang_1/lang_2 < 1.1
return 0.85 < lang_1/lang_2 < 1.2
random_sampler = RandomSampler(train_samples)
random_sampler = torch.utils.data.RandomSampler(train_samples)
ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
en, pt = 0, 0
for id in ids:
if train_samples[id][3] == 'en':
for index in ids:
if train_samples[index][3] == 'en':
en += 1
else:
pt += 1
@ -44,10 +46,10 @@ assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"
weighted_sampler = get_language_weighted_sampler(train_samples)
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
en, pt = 0, 0
for id in ids:
if train_samples[id][3] == 'en':
for index in ids:
if train_samples[index][3] == 'en':
en += 1
else:
pt += 1
assert is_balanced(en, pt), "Weighted sampler is supposed to be balanced"
assert is_balanced(en, pt), "Weighted sampler is supposed to be balanced"