mirror of https://github.com/coqui-ai/TTS.git
Fix seed in test_samplers to avoid random fails
This commit is contained in:
parent
22c7be5f44
commit
7b81c16434
|
@ -1,10 +1,12 @@
|
||||||
from torch.utils.data import RandomSampler
|
|
||||||
from TTS.tts.datasets import load_tts_samples
|
from TTS.tts.datasets import load_tts_samples
|
||||||
from TTS.config.shared_configs import BaseDatasetConfig
|
from TTS.config.shared_configs import BaseDatasetConfig
|
||||||
from TTS.tts.utils.languages import get_language_weighted_sampler
|
from TTS.tts.utils.languages import get_language_weighted_sampler
|
||||||
|
import torch
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
|
# Fixing random state to avoid random fails
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
dataset_config_en = BaseDatasetConfig(
|
dataset_config_en = BaseDatasetConfig(
|
||||||
name="ljspeech",
|
name="ljspeech",
|
||||||
meta_file_train="metadata.csv",
|
meta_file_train="metadata.csv",
|
||||||
|
@ -28,13 +30,13 @@ train_samples, eval_samples = load_tts_samples(
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_balanced(lang_1, lang_2):
|
def is_balanced(lang_1, lang_2):
|
||||||
return 0.9 < lang_1/lang_2 < 1.1
|
return 0.85 < lang_1/lang_2 < 1.2
|
||||||
|
|
||||||
random_sampler = RandomSampler(train_samples)
|
random_sampler = torch.utils.data.RandomSampler(train_samples)
|
||||||
ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
|
ids = functools.reduce(lambda a, b: a + b, [list(random_sampler) for i in range(100)])
|
||||||
en, pt = 0, 0
|
en, pt = 0, 0
|
||||||
for id in ids:
|
for index in ids:
|
||||||
if train_samples[id][3] == 'en':
|
if train_samples[index][3] == 'en':
|
||||||
en += 1
|
en += 1
|
||||||
else:
|
else:
|
||||||
pt += 1
|
pt += 1
|
||||||
|
@ -44,8 +46,8 @@ assert not is_balanced(en, pt), "Random sampler is supposed to be unbalanced"
|
||||||
weighted_sampler = get_language_weighted_sampler(train_samples)
|
weighted_sampler = get_language_weighted_sampler(train_samples)
|
||||||
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
|
ids = functools.reduce(lambda a, b: a + b, [list(weighted_sampler) for i in range(100)])
|
||||||
en, pt = 0, 0
|
en, pt = 0, 0
|
||||||
for id in ids:
|
for index in ids:
|
||||||
if train_samples[id][3] == 'en':
|
if train_samples[index][3] == 'en':
|
||||||
en += 1
|
en += 1
|
||||||
else:
|
else:
|
||||||
pt += 1
|
pt += 1
|
||||||
|
|
Loading…
Reference in New Issue