mirror of https://github.com/coqui-ai/TTS.git
Update TTS.tts formatters (#1228)
* Return Dict from tts formatters
* Make style

parent 5e3f499a69 → commit 127118c637
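The core of the change: formatters used to return list-shaped samples ([text, audio_file, speaker_name]); they now return dicts keyed by field name, and downstream code reads fields by key. As a rough orientation before the hunks below, one item now looks like this (the values are made-up placeholders, not taken from the commit):

    # Illustrative only: the shape of one formatter item after this change.
    item = {
        "text": "hello world.",
        "audio_file": "/datasets/ljspeech/wavs/LJ001-0001.wav",  # placeholder path
        "speaker_name": "ljspeech",
    }
    # load_tts_samples() then merges the dataset language into every item:
    item = {**item, **{"language": "en"}}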
@@ -29,7 +29,9 @@ parser.add_argument(
     help="Path to dataset config file.",
 )
 parser.add_argument("output_path", type=str, help="path for output speakers.json and/or speakers.npy.")
-parser.add_argument("--old_file", type=str, help="Previous speakers.json file, only compute for new audios.", default=None)
+parser.add_argument(
+    "--old_file", type=str, help="Previous speakers.json file, only compute for new audios.", default=None
+)
 parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
 parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
@@ -41,7 +43,10 @@ meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_spli
 wav_files = meta_data_train + meta_data_eval

 speaker_manager = SpeakerManager(
-    encoder_model_path=args.model_path, encoder_config_path=args.config_path, d_vectors_file_path=args.old_file, use_cuda=args.use_cuda
+    encoder_model_path=args.model_path,
+    encoder_config_path=args.config_path,
+    d_vectors_file_path=args.old_file,
+    use_cuda=args.use_cuda,
 )

 # compute speaker embeddings
@@ -51,7 +51,7 @@ def main():
     N = 0
     for item in tqdm(dataset_items):
         # compute features
-        wav = ap.load_wav(item if isinstance(item, str) else item[1])
+        wav = ap.load_wav(item if isinstance(item, str) else item["audio_file"])
         linear = ap.spectrogram(wav)
         mel = ap.melspectrogram(wav)
@@ -24,6 +24,7 @@ def main():

     # load all datasets
     train_items, eval_items = load_tts_samples(c.datasets, eval_split=True)
+
     items = train_items + eval_items

     texts = "".join(item[0] for item in items)
@@ -43,6 +43,11 @@ def main():
     items = train_items + eval_items
+    print("Num items:", len(items))
+
+    is_lang_def = all(item["language"] for item in items)

+    if not c.phoneme_language or not is_lang_def:
+        raise ValueError("Phoneme language must be defined in config.")

     phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)
     phones = []
     for ph in phonemes:
@@ -1,4 +1,5 @@
 import os

+import torch

 from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
@@ -78,12 +78,12 @@ class SpeakerEncoderDataset(Dataset):
         mel = self.ap.melspectrogram(wav).astype("float32")
         # sample seq_len

-        assert text.size > 0, self.items[idx][1]
-        assert wav.size > 0, self.items[idx][1]
+        assert text.size > 0, self.items[idx]["audio_file"]
+        assert wav.size > 0, self.items[idx]["audio_file"]

         sample = {
             "mel": mel,
-            "item_idx": self.items[idx][1],
+            "item_idx": self.items[idx]["audio_file"],
             "speaker_name": speaker_name,
         }
         return sample
@@ -91,8 +91,8 @@ class SpeakerEncoderDataset(Dataset):
     def __parse_items(self):
         self.speaker_to_utters = {}
         for i in self.items:
-            path_ = i[1]
-            speaker_ = i[2]
+            path_ = i["audio_file"]
+            speaker_ = i["speaker_name"]
             if speaker_ in self.speaker_to_utters.keys():
                 self.speaker_to_utters[speaker_].append(path_)
             else:
@@ -75,14 +75,14 @@ def load_tts_samples(
         formatter = _get_formatter_by_name(name)
         # load train set
         meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
-        meta_data_train = [[*item, language] for item in meta_data_train]
+        meta_data_train = [{**item, **{"language": language}} for item in meta_data_train]

         print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
         # load evaluation split if set
         if eval_split:
             if meta_file_val:
                 meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
-                meta_data_eval = [[*item, language] for item in meta_data_eval]
+                meta_data_eval = [{**item, **{"language": language}} for item in meta_data_eval]
             else:
                 meta_data_eval, meta_data_train = split_dataset(meta_data_train)
             meta_data_eval_all += meta_data_eval
@@ -91,12 +91,12 @@ def load_tts_samples(
         if dataset.meta_file_attn_mask:
             meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
             for idx, ins in enumerate(meta_data_train_all):
-                attn_file = meta_data[ins[1]].strip()
-                meta_data_train_all[idx].append(attn_file)
+                attn_file = meta_data[ins["audio_file"]].strip()
+                meta_data_train_all[idx].update({"alignment_file": attn_file})
             if meta_data_eval_all:
                 for idx, ins in enumerate(meta_data_eval_all):
-                    attn_file = meta_data[ins[1]].strip()
-                    meta_data_eval_all[idx].append(attn_file)
+                    attn_file = meta_data[ins["audio_file"]].strip()
+                    meta_data_eval_all[idx].update({"alignment_file": attn_file})
         # set none for the next iter
         formatter = None
     return meta_data_train_all, meta_data_eval_all
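With the two hunks above, load_tts_samples() returns lists of dicts and stores the optional attention-mask path under "alignment_file". A hedged usage sketch of the new return value (the dataset name, path and config fields here are illustrative assumptions, not part of this commit):

    from TTS.config.shared_configs import BaseDatasetConfig
    from TTS.tts.datasets import load_tts_samples

    # Placeholder dataset config; point name/path/meta_file_train at a real dataset.
    dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path="/datasets/LJSpeech-1.1")
    train_samples, eval_samples = load_tts_samples([dataset_config], eval_split=True)

    for sample in train_samples[:3]:
        # every sample is a dict now, so fields are read by key
        print(sample["text"], sample["audio_file"], sample["speaker_name"])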
@@ -21,7 +21,7 @@ class TTSDataset(Dataset):
         text_cleaner: list,
         compute_linear_spec: bool,
         ap: AudioProcessor,
-        meta_data: List[List],
+        meta_data: List[Dict],
         compute_f0: bool = False,
         f0_cache_path: str = None,
         characters: Dict = None,
@@ -54,7 +54,7 @@ class TTSDataset(Dataset):

             ap (TTS.tts.utils.AudioProcessor): Audio processor object.

-            meta_data (list): List of dataset instances.
+            meta_data (list): List of dataset samples.

             compute_f0 (bool): compute f0 if True. Defaults to False.
@@ -199,15 +199,9 @@ class TTSDataset(Dataset):

     def load_data(self, idx):
         item = self.items[idx]
+        raw_text = item["text"]

-        if len(item) == 5:
-            text, wav_file, speaker_name, language_name, attn_file = item
-        else:
-            text, wav_file, speaker_name, language_name = item
-            attn = None
-        raw_text = text
-
-        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
+        wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)

         # apply noise for augmentation
         if self.use_noise_augment:
@@ -216,12 +210,12 @@ class TTSDataset(Dataset):
         if not self.input_seq_computed:
             if self.use_phonemes:
                 text = self._load_or_generate_phoneme_sequence(
-                    wav_file,
-                    text,
+                    item["audio_file"],
+                    item["text"],
                     self.phoneme_cache_path,
                     self.enable_eos_bos,
                     self.cleaners,
-                    language_name if language_name else self.phoneme_language,
+                    item["language"] if item["language"] else self.phoneme_language,
                     self.custom_symbols,
                     self.characters,
                     self.add_blank,
@@ -229,7 +223,7 @@ class TTSDataset(Dataset):
             else:
                 text = np.asarray(
                     text_to_sequence(
-                        text,
+                        item["text"],
                         [self.cleaners],
                         custom_symbols=self.custom_symbols,
                         tp=self.characters,
@@ -238,11 +232,12 @@ class TTSDataset(Dataset):
                     dtype=np.int32,
                 )

-        assert text.size > 0, self.items[idx][1]
-        assert wav.size > 0, self.items[idx][1]
+        assert text.size > 0, self.items[idx]["audio_file"]
+        assert wav.size > 0, self.items[idx]["audio_file"]

-        if "attn_file" in locals():
-            attn = np.load(attn_file)
+        attn = None
+        if "alignment_file" in item:
+            attn = np.load(item["alignment_file"])

         if len(text) > self.max_seq_len:
             # return a different sample if the phonemized
@@ -252,7 +247,7 @@ class TTSDataset(Dataset):

         pitch = None
         if self.compute_f0:
-            pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
+            pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, item["audio_file"], self.f0_cache_path)
             pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))

         sample = {
@@ -261,10 +256,10 @@ class TTSDataset(Dataset):
             "wav": wav,
             "pitch": pitch,
             "attn": attn,
-            "item_idx": self.items[idx][1],
-            "speaker_name": speaker_name,
-            "language_name": language_name,
-            "wav_file_name": os.path.basename(wav_file),
+            "item_idx": item["audio_file"],
+            "speaker_name": item["speaker_name"],
+            "language_name": item["language"],
+            "wav_file_name": os.path.basename(item["audio_file"]),
         }
         return sample
@@ -272,11 +267,10 @@ class TTSDataset(Dataset):
     def _phoneme_worker(args):
         item = args[0]
         func_args = args[1]
-        text, wav_file, *_ = item
         func_args[3] = (
-            item[3] if item[3] else func_args[3]
+            item["language"] if "language" in item and item["language"] else func_args[3]
         )  # override phoneme language if specified by the dataset formatter
-        phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
+        phonemes = TTSDataset._load_or_generate_phoneme_sequence(item["audio_file"], item["text"], *func_args)
         return phonemes

     def compute_input_seq(self, num_workers=0):
@@ -286,10 +280,9 @@ class TTSDataset(Dataset):
         if self.verbose:
             print(" | > Computing input sequences ...")
         for idx, item in enumerate(tqdm.tqdm(self.items)):
-            text, *_ = item
             sequence = np.asarray(
                 text_to_sequence(
-                    text,
+                    item["text"],
                     [self.cleaners],
                     custom_symbols=self.custom_symbols,
                     tp=self.characters,
@@ -337,10 +330,10 @@ class TTSDataset(Dataset):
         if by_audio_len:
             lengths = []
             for item in self.items:
-                lengths.append(os.path.getsize(item[1]) / 16 * 8)  # assuming 16bit audio
+                lengths.append(os.path.getsize(item["audio_file"]) / 16 * 8)  # assuming 16bit audio
             lengths = np.array(lengths)
         else:
-            lengths = np.array([len(ins[0]) for ins in self.items])
+            lengths = np.array([len(ins["text"]) for ins in self.items])

         idxs = np.argsort(lengths)
         new_items = []
@@ -555,7 +548,7 @@ class PitchExtractor:

     def __init__(
         self,
-        items: List[List],
+        items: List[Dict],
         verbose=False,
     ):
         self.items = items
@@ -614,10 +607,9 @@ class PitchExtractor:
         item = args[0]
         ap = args[1]
         cache_path = args[2]
-        _, wav_file, *_ = item
-        pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
+        pitch_file = PitchExtractor.create_pitch_file_path(item["audio_file"], cache_path)
         if not os.path.exists(pitch_file):
-            pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
+            pitch = PitchExtractor._compute_and_save_pitch(ap, item["audio_file"], pitch_file)
             return pitch
         return None
@@ -24,7 +24,7 @@ def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
             cols = line.split("\t")
             wav_file = os.path.join(root_path, cols[0] + ".wav")
             text = cols[1]
-            items.append([text, wav_file, speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items
@@ -39,7 +39,7 @@ def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
             wav_file = cols[1].strip()
             text = cols[0].strip()
             wav_file = os.path.join(root_path, "wavs", wav_file)
-            items.append([text, wav_file, speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items
@@ -55,7 +55,7 @@ def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argume
             text = cols[1].strip()
             folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
             wav_file = os.path.join(root_path, folder_name, wav_file)
-            items.append([text, wav_file, speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items
@@ -101,7 +101,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
                    wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
                if os.path.isfile(wav_file):
                    text = cols[1].strip()
-                   items.append([text, wav_file, speaker_name])
+                   items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
                else:
                    # M-AI-Labs have some missing samples, so just print the warning
                    print("> File %s does not exist!" % (wav_file))
@@ -119,7 +119,7 @@ def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2]
-            items.append([text, wav_file, speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items
@@ -133,7 +133,7 @@ def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2]
-            items.append([text, wav_file, f"ljspeech-{idx}"])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{idx}"})
     return items
@@ -150,7 +150,7 @@ def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
         if not os.path.exists(wav_file):
             print(f" [!] {wav_file} in metafile does not exist. Skipping...")
             continue
-        items.append([text, wav_file, speaker_name])
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items
@@ -165,7 +165,7 @@ def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
             text = cols[1]
-            items.append([text, wav_file, speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items
@@ -179,7 +179,7 @@ def css10(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, cols[0])
             text = cols[1]
-            items.append([text, wav_file, speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items
@@ -193,7 +193,7 @@ def nancy(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
             utt_id = line.split()[1]
             text = line[line.find('"') + 1 : line.rfind('"') - 1]
             wav_file = os.path.join(root_path, "wavn", utt_id + ".wav")
-            items.append([text, wav_file, speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items
@@ -213,7 +213,7 @@ def common_voice(root_path, meta_file, ignored_speakers=None):
                if speaker_name in ignored_speakers:
                    continue
            wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
-           items.append([text, wav_file, "MCV_" + speaker_name])
+           items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name})
     return items
@@ -240,7 +240,7 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None):
                if isinstance(ignored_speakers, list):
                    if speaker_name in ignored_speakers:
                        continue
-               items.append([text, wav_file, "LTTS_" + speaker_name])
+               items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"})
     for item in items:
         assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}"
     return items
@@ -259,7 +259,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar
                skipped_files.append(wav_file)
                continue
            text = cols[1].strip()
-           items.append([text, wav_file, speaker_name])
+           items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
     return items
@@ -281,7 +281,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
            if isinstance(ignored_speakers, list):
                if speaker_id in ignored_speakers:
                    continue
-           items.append([text, wav_file, speaker_id])
+           items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
     return items
@@ -299,7 +299,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
         with open(meta_file, "r", encoding="utf-8") as file_text:
             text = file_text.readlines()[0]
         wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
-        items.append([text, wav_file, "VCTK_" + speaker_id])
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id})

     return items
@@ -334,7 +334,7 @@ def mls(root_path, meta_files=None, ignored_speakers=None):
            if isinstance(ignored_speakers, list):
                if speaker in ignored_speakers:
                    continue
-           items.append([text, wav_file, "MLS_" + speaker])
+           items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker})
     return items
@@ -404,7 +404,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylin
         for line in ttf:
             wav_name, text = line.rstrip("\n").split("|")
             wav_path = os.path.join(root_path, "clips_22", wav_name)
-            items.append([text, wav_path, speaker_name])
+            items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name})
     return items
@@ -418,5 +418,5 @@ def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2].replace(" ", "")
-            items.append([text, wav_file, speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items
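All bundled formatters above now append dicts with "text", "audio_file" and "speaker_name" keys instead of positional lists. A custom formatter written against this interface follows the same pattern; a hedged sketch (the function name, file layout and metadata format are assumptions for illustration, not part of TTS):

    import os

    def my_dataset(root_path, meta_file, **kwargs):  # hypothetical formatter, not shipped with TTS
        # Assumes a `wav_id|transcript` metadata file and a `wavs/` folder; adjust to your data.
        items = []
        speaker_name = "my_speaker"
        with open(os.path.join(root_path, meta_file), "r", encoding="utf-8") as ttf:
            for line in ttf:
                cols = line.strip().split("|")
                wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
                text = cols[1]
                items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
        return items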
@@ -4,7 +4,6 @@ from itertools import chain
 from typing import Dict, List, Tuple

 import torch
-
 import torchaudio
 from coqpit import Coqpit
 from torch import nn
@@ -692,10 +691,17 @@ class Vits(BaseTTS):

         if self.args.use_sdp:
             logw = self.duration_predictor(
-                x, x_mask, g=g if self.args.condition_dp_on_speaker else None, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
+                x,
+                x_mask,
+                g=g if self.args.condition_dp_on_speaker else None,
+                reverse=True,
+                noise_scale=self.inference_noise_scale_dp,
+                lang_emb=lang_emb,
             )
         else:
-            logw = self.duration_predictor(x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb)
+            logw = self.duration_predictor(
+                x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
+            )

         w = torch.exp(logw) * x_mask * self.length_scale
         w_ceil = torch.ceil(w)
@@ -113,7 +113,7 @@ def _set_file_path(path):


 def get_language_weighted_sampler(items: list):
-    language_names = np.array([item[3] for item in items])
+    language_names = np.array([item["language"] for item in items])
     unique_language_names = np.unique(language_names).tolist()
     language_ids = [unique_language_names.index(l) for l in language_names]
     language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
@@ -118,7 +118,7 @@ class SpeakerManager:
         Returns:
             Tuple[Dict, int]: speaker IDs and number of speakers.
         """
-        speakers = sorted({item[2] for item in items})
+        speakers = sorted({item["speaker_name"] for item in items})
         speaker_ids = {name: i for i, name in enumerate(speakers)}
         num_speakers = len(speaker_ids)
         return speaker_ids, num_speakers
@@ -414,7 +414,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,


 def get_speaker_weighted_sampler(items: list):
-    speaker_names = np.array([item[2] for item in items])
+    speaker_names = np.array([item["speaker_name"] for item in items])
     unique_speaker_names = np.unique(speaker_names).tolist()
     speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
     speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
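Both weighted-sampler helpers now read their labels by key from the dict items; the rest of each function is not shown in this diff. As a rough sketch of the idea only (not the repo's exact code), the per-speaker counts can be turned into a torch sampler like this:

    import numpy as np
    import torch
    from torch.utils.data.sampler import WeightedRandomSampler

    # Toy items in the new dict format; names and paths are placeholders.
    items = [
        {"text": "a", "audio_file": "1.wav", "speaker_name": "spk1"},
        {"text": "b", "audio_file": "2.wav", "speaker_name": "spk1"},
        {"text": "c", "audio_file": "3.wav", "speaker_name": "spk2"},
    ]
    speaker_names = np.array([item["speaker_name"] for item in items])
    unique_speaker_names = np.unique(speaker_names).tolist()
    speaker_ids = [unique_speaker_names.index(s) for s in speaker_names]
    speaker_count = np.array([len(np.where(speaker_names == s)[0]) for s in unique_speaker_names])
    weights = 1.0 / speaker_count[speaker_ids]  # rarer speakers get sampled more often
    sampler = WeightedRandomSampler(torch.from_numpy(weights), len(weights))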
@@ -127,5 +127,7 @@ class ParallelWaveganConfig(BaseGANVocoderConfig):
     lr_scheduler_gen: str = "StepLR"  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
     lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1})
     lr_scheduler_disc: str = "StepLR"  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
-    lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1})
+    lr_scheduler_disc_params: dict = field(
+        default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1}
+    )
     scheduler_after_epoch: bool = False
@@ -5,13 +5,13 @@ from tests import get_tests_input_path
 from TTS.tts.datasets.formatters import common_voice


-class TestPreprocessors(unittest.TestCase):
+class TestTTSFormatters(unittest.TestCase):
     def test_common_voice_preprocessor(self):  # pylint: disable=no-self-use
         root_path = get_tests_input_path()
         meta_file = "common_voice.tsv"
         items = common_voice(root_path, meta_file)
-        assert items[0][0] == "The applicants are invited for coffee and visa is given immediately."
-        assert items[0][1] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_20005954.wav")
+        assert items[0]["text"] == "The applicants are invited for coffee and visa is given immediately."
+        assert items[0]["audio_file"] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_20005954.wav")

-        assert items[-1][0] == "Competition for limited resources has also resulted in some local conflicts."
-        assert items[-1][1] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_19737074.wav")
+        assert items[-1]["text"] == "Competition for limited resources has also resulted in some local conflicts."
+        assert items[-1]["audio_file"] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_19737074.wav")