Update TTS.tts formatters (#1228)

* Return Dict from tts formatters

* Make style
Eren Gölge 2022-02-11 23:03:43 +01:00 committed by GitHub
parent 5e3f499a69
commit 127118c637
41 changed files with 153 additions and 141 deletions
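
The change in one sentence: the TTS.tts dataset formatters used to return each sample as a positional list, and after this commit they return a dict keyed by field name, with load_tts_samples() merging extra per-dataset fields into the same dict. Below is a minimal sketch of the shape change, using the key names visible in the hunks that follow; the example values are made up and the snippet is illustrative, not code from the repository.

# Illustrative sketch only: sample shapes inferred from the diff below, example values are made up.

# Before: formatters emitted positional lists, so callers indexed by position.
old_item = ["Hello world.", "/data/wavs/sample_001.wav", "ljspeech"]
text, audio_path, speaker = old_item[0], old_item[1], old_item[2]

# After: formatters emit dicts, so callers access fields by name.
new_item = {
    "text": "Hello world.",
    "audio_file": "/data/wavs/sample_001.wav",
    "speaker_name": "ljspeech",
}
# load_tts_samples() then merges per-dataset metadata into the same dict,
# e.g. "language" and, when attention masks are configured, "alignment_file".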

View File

@@ -29,7 +29,9 @@ parser.add_argument(
     help="Path to dataset config file.",
 )
 parser.add_argument("output_path", type=str, help="path for output speakers.json and/or speakers.npy.")
-parser.add_argument("--old_file", type=str, help="Previous speakers.json file, only compute for new audios.", default=None)
+parser.add_argument(
+    "--old_file", type=str, help="Previous speakers.json file, only compute for new audios.", default=None
+)
 parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
 parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
@@ -41,7 +43,10 @@ meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_spli
 wav_files = meta_data_train + meta_data_eval
 speaker_manager = SpeakerManager(
-    encoder_model_path=args.model_path, encoder_config_path=args.config_path, d_vectors_file_path=args.old_file, use_cuda=args.use_cuda
+    encoder_model_path=args.model_path,
+    encoder_config_path=args.config_path,
+    d_vectors_file_path=args.old_file,
+    use_cuda=args.use_cuda,
 )
 # compute speaker embeddings

View File

@@ -51,7 +51,7 @@ def main():
     N = 0
     for item in tqdm(dataset_items):
         # compute features
-        wav = ap.load_wav(item if isinstance(item, str) else item[1])
+        wav = ap.load_wav(item if isinstance(item, str) else item["audio_file"])
         linear = ap.spectrogram(wav)
         mel = ap.melspectrogram(wav)

View File

@@ -24,6 +24,7 @@ def main():
     # load all datasets
     train_items, eval_items = load_tts_samples(c.datasets, eval_split=True)
     items = train_items + eval_items
     texts = "".join(item[0] for item in items)

View File

@@ -43,6 +43,11 @@ def main():
     items = train_items + eval_items
     print("Num items:", len(items))
+    is_lang_def = all(item["language"] for item in items)
+    if not c.phoneme_language or not is_lang_def:
+        raise ValueError("Phoneme language must be defined in config.")
     phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)
     phones = []
     for ph in phonemes:

View File

@@ -1,4 +1,5 @@
 import os
 import torch
 from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config

View File

@@ -78,12 +78,12 @@ class SpeakerEncoderDataset(Dataset):
         mel = self.ap.melspectrogram(wav).astype("float32")
         # sample seq_len
-        assert text.size > 0, self.items[idx][1]
-        assert wav.size > 0, self.items[idx][1]
+        assert text.size > 0, self.items[idx]["audio_file"]
+        assert wav.size > 0, self.items[idx]["audio_file"]
         sample = {
             "mel": mel,
-            "item_idx": self.items[idx][1],
+            "item_idx": self.items[idx]["audio_file"],
             "speaker_name": speaker_name,
         }
         return sample
@@ -91,8 +91,8 @@
     def __parse_items(self):
         self.speaker_to_utters = {}
         for i in self.items:
-            path_ = i[1]
-            speaker_ = i[2]
+            path_ = i["audio_file"]
+            speaker_ = i["speaker_name"]
             if speaker_ in self.speaker_to_utters.keys():
                 self.speaker_to_utters[speaker_].append(path_)
             else:

View File

@@ -75,14 +75,14 @@ def load_tts_samples(
         formatter = _get_formatter_by_name(name)
         # load train set
         meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
-        meta_data_train = [[*item, language] for item in meta_data_train]
+        meta_data_train = [{**item, **{"language": language}} for item in meta_data_train]
         print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
         # load evaluation split if set
         if eval_split:
             if meta_file_val:
                 meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
-                meta_data_eval = [[*item, language] for item in meta_data_eval]
+                meta_data_eval = [{**item, **{"language": language}} for item in meta_data_eval]
             else:
                 meta_data_eval, meta_data_train = split_dataset(meta_data_train)
             meta_data_eval_all += meta_data_eval
@@ -91,12 +91,12 @@ load_tts_samples(
         if dataset.meta_file_attn_mask:
             meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
             for idx, ins in enumerate(meta_data_train_all):
-                attn_file = meta_data[ins[1]].strip()
-                meta_data_train_all[idx].append(attn_file)
+                attn_file = meta_data[ins["audio_file"]].strip()
+                meta_data_train_all[idx].update({"alignment_file": attn_file})
             if meta_data_eval_all:
                 for idx, ins in enumerate(meta_data_eval_all):
-                    attn_file = meta_data[ins[1]].strip()
-                    meta_data_eval_all[idx].append(attn_file)
+                    attn_file = meta_data[ins["audio_file"]].strip()
+                    meta_data_eval_all[idx].update({"alignment_file": attn_file})
         # set none for the next iter
         formatter = None
     return meta_data_train_all, meta_data_eval_all
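
For readability, here is the dict-merge idiom from the load_tts_samples() hunk above, shown in isolation; this is a simplified sketch with hypothetical values, not the full function body.

# Simplified sketch of the merge pattern above; values are hypothetical.
language = "en"
meta_data_train = [{"text": "Hello.", "audio_file": "wavs/a.wav", "speaker_name": "spk0"}]

# Attach the dataset-level language to every sample dict.
meta_data_train = [{**item, "language": language} for item in meta_data_train]

# When an attention-mask meta file is configured, attach the alignment path keyed by audio file.
attn_map = {"wavs/a.wav": "alignments/a.npy"}  # hypothetical mapping
for idx, ins in enumerate(meta_data_train):
    meta_data_train[idx].update({"alignment_file": attn_map[ins["audio_file"]]})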

View File

@@ -21,7 +21,7 @@ class TTSDataset(Dataset):
         text_cleaner: list,
         compute_linear_spec: bool,
         ap: AudioProcessor,
-        meta_data: List[List],
+        meta_data: List[Dict],
         compute_f0: bool = False,
         f0_cache_path: str = None,
         characters: Dict = None,
@@ -54,7 +54,7 @@
             ap (TTS.tts.utils.AudioProcessor): Audio processor object.
-            meta_data (list): List of dataset instances.
+            meta_data (list): List of dataset samples.
             compute_f0 (bool): compute f0 if True. Defaults to False.
@@ -199,15 +199,9 @@ class TTSDataset(Dataset):
     def load_data(self, idx):
         item = self.items[idx]
-        if len(item) == 5:
-            text, wav_file, speaker_name, language_name, attn_file = item
-        else:
-            text, wav_file, speaker_name, language_name = item
-            attn = None
-        raw_text = text
-        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
+        raw_text = item["text"]
+        wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)
         # apply noise for augmentation
         if self.use_noise_augment:
@@ -216,12 +210,12 @@
         if not self.input_seq_computed:
             if self.use_phonemes:
                 text = self._load_or_generate_phoneme_sequence(
-                    wav_file,
-                    text,
+                    item["audio_file"],
+                    item["text"],
                     self.phoneme_cache_path,
                     self.enable_eos_bos,
                     self.cleaners,
-                    language_name if language_name else self.phoneme_language,
+                    item["language"] if item["language"] else self.phoneme_language,
                     self.custom_symbols,
                     self.characters,
                     self.add_blank,
@@ -229,7 +223,7 @@
             else:
                 text = np.asarray(
                     text_to_sequence(
-                        text,
+                        item["text"],
                         [self.cleaners],
                         custom_symbols=self.custom_symbols,
                         tp=self.characters,
@@ -238,11 +232,12 @@
                     dtype=np.int32,
                 )
-        assert text.size > 0, self.items[idx][1]
-        assert wav.size > 0, self.items[idx][1]
+        assert text.size > 0, self.items[idx]["audio_file"]
+        assert wav.size > 0, self.items[idx]["audio_file"]
-        if "attn_file" in locals():
-            attn = np.load(attn_file)
+        attn = None
+        if "alignment_file" in item:
+            attn = np.load(item["alignment_file"])
         if len(text) > self.max_seq_len:
             # return a different sample if the phonemized
@@ -252,7 +247,7 @@
         pitch = None
         if self.compute_f0:
-            pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
+            pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, item["audio_file"], self.f0_cache_path)
             pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))
         sample = {
@@ -261,10 +256,10 @@
             "wav": wav,
             "pitch": pitch,
             "attn": attn,
-            "item_idx": self.items[idx][1],
-            "speaker_name": speaker_name,
-            "language_name": language_name,
-            "wav_file_name": os.path.basename(wav_file),
+            "item_idx": item["audio_file"],
+            "speaker_name": item["speaker_name"],
+            "language_name": item["language"],
+            "wav_file_name": os.path.basename(item["audio_file"]),
         }
         return sample
@@ -272,11 +267,10 @@
     def _phoneme_worker(args):
         item = args[0]
         func_args = args[1]
-        text, wav_file, *_ = item
         func_args[3] = (
-            item[3] if item[3] else func_args[3]
+            item["language"] if "language" in item and item["language"] else func_args[3]
         ) # override phoneme language if specified by the dataset formatter
-        phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
+        phonemes = TTSDataset._load_or_generate_phoneme_sequence(item["audio_file"], item["text"], *func_args)
         return phonemes
     def compute_input_seq(self, num_workers=0):
@@ -286,10 +280,9 @@
         if self.verbose:
             print(" | > Computing input sequences ...")
         for idx, item in enumerate(tqdm.tqdm(self.items)):
-            text, *_ = item
             sequence = np.asarray(
                 text_to_sequence(
-                    text,
+                    item["text"],
                     [self.cleaners],
                     custom_symbols=self.custom_symbols,
                     tp=self.characters,
@@ -337,10 +330,10 @@
         if by_audio_len:
             lengths = []
             for item in self.items:
-                lengths.append(os.path.getsize(item[1]) / 16 * 8) # assuming 16bit audio
+                lengths.append(os.path.getsize(item["audio_file"]) / 16 * 8) # assuming 16bit audio
             lengths = np.array(lengths)
         else:
-            lengths = np.array([len(ins[0]) for ins in self.items])
+            lengths = np.array([len(ins["text"]) for ins in self.items])
         idxs = np.argsort(lengths)
         new_items = []
@@ -555,7 +548,7 @@ class PitchExtractor:
     def __init__(
         self,
-        items: List[List],
+        items: List[Dict],
         verbose=False,
     ):
         self.items = items
@@ -614,10 +607,9 @@
         item = args[0]
         ap = args[1]
         cache_path = args[2]
-        _, wav_file, *_ = item
-        pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
+        pitch_file = PitchExtractor.create_pitch_file_path(item["audio_file"], cache_path)
         if not os.path.exists(pitch_file):
-            pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
+            pitch = PitchExtractor._compute_and_save_pitch(ap, item["audio_file"], pitch_file)
             return pitch
         return None

View File

@@ -24,7 +24,7 @@ def tweb(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
            cols = line.split("\t")
            wav_file = os.path.join(root_path, cols[0] + ".wav")
            text = cols[1]
-           items.append([text, wav_file, speaker_name])
+           items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items
@@ -39,7 +39,7 @@ def mozilla(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
            wav_file = cols[1].strip()
            text = cols[0].strip()
            wav_file = os.path.join(root_path, "wavs", wav_file)
-           items.append([text, wav_file, speaker_name])
+           items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items
@@ -55,7 +55,7 @@ def mozilla_de(root_path, meta_file, **kwargs): # pylint: disable=unused-argume
            text = cols[1].strip()
            folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
            wav_file = os.path.join(root_path, folder_name, wav_file)
-           items.append([text, wav_file, speaker_name])
+           items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items
@@ -101,7 +101,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
                    wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
                if os.path.isfile(wav_file):
                    text = cols[1].strip()
-                   items.append([text, wav_file, speaker_name])
+                   items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
                else:
                    # M-AI-Labs have some missing samples, so just print the warning
                    print("> File %s does not exist!" % (wav_file))
@@ -119,7 +119,7 @@ def ljspeech(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
            cols = line.split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
            text = cols[2]
-           items.append([text, wav_file, speaker_name])
+           items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items
@@ -133,7 +133,7 @@ def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
            cols = line.split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
            text = cols[2]
-           items.append([text, wav_file, f"ljspeech-{idx}"])
+           items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{idx}"})
    return items
@@ -150,7 +150,7 @@ def sam_accenture(root_path, meta_file, **kwargs): # pylint: disable=unused-arg
        if not os.path.exists(wav_file):
            print(f" [!] {wav_file} in metafile does not exist. Skipping...")
            continue
-       items.append([text, wav_file, speaker_name])
+       items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items
@@ -165,7 +165,7 @@ def ruslan(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
            cols = line.split("|")
            wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
            text = cols[1]
-           items.append([text, wav_file, speaker_name])
+           items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items
@@ -179,7 +179,7 @@ def css10(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
            cols = line.split("|")
            wav_file = os.path.join(root_path, cols[0])
            text = cols[1]
-           items.append([text, wav_file, speaker_name])
+           items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items
@@ -193,7 +193,7 @@ def nancy(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
            utt_id = line.split()[1]
            text = line[line.find('"') + 1 : line.rfind('"') - 1]
            wav_file = os.path.join(root_path, "wavn", utt_id + ".wav")
-           items.append([text, wav_file, speaker_name])
+           items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items
@@ -213,7 +213,7 @@ def common_voice(root_path, meta_file, ignored_speakers=None):
                if speaker_name in ignored_speakers:
                    continue
            wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
-           items.append([text, wav_file, "MCV_" + speaker_name])
+           items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name})
    return items
@@ -240,7 +240,7 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None):
                if isinstance(ignored_speakers, list):
                    if speaker_name in ignored_speakers:
                        continue
-               items.append([text, wav_file, "LTTS_" + speaker_name])
+               items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"})
    for item in items:
        assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}"
    return items
@@ -259,7 +259,7 @@ def custom_turkish(root_path, meta_file, **kwargs): # pylint: disable=unused-ar
                skipped_files.append(wav_file)
                continue
            text = cols[1].strip()
-           items.append([text, wav_file, speaker_name])
+           items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
    return items
@@ -281,7 +281,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
            if isinstance(ignored_speakers, list):
                if speaker_id in ignored_speakers:
                    continue
-           items.append([text, wav_file, speaker_id])
+           items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
    return items
@@ -299,7 +299,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
        with open(meta_file, "r", encoding="utf-8") as file_text:
            text = file_text.readlines()[0]
        wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
-       items.append([text, wav_file, "VCTK_" + speaker_id])
+       items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id})
    return items
@@ -334,7 +334,7 @@ def mls(root_path, meta_files=None, ignored_speakers=None):
            if isinstance(ignored_speakers, list):
                if speaker in ignored_speakers:
                    continue
-           items.append([text, wav_file, "MLS_" + speaker])
+           items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker})
    return items
@@ -404,7 +404,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]: # pylin
        for line in ttf:
            wav_name, text = line.rstrip("\n").split("|")
            wav_path = os.path.join(root_path, "clips_22", wav_name)
-           items.append([text, wav_path, speaker_name])
+           items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name})
    return items
@@ -418,5 +418,5 @@ def kokoro(root_path, meta_file, **kwargs): # pylint: disable=unused-argument
            cols = line.split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
            text = cols[2].replace(" ", "")
-           items.append([text, wav_file, speaker_name])
+           items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items

View File

@@ -4,7 +4,6 @@ from itertools import chain
 from typing import Dict, List, Tuple
 import torch
-import torchaudio
 from coqpit import Coqpit
 from torch import nn
@@ -692,10 +691,17 @@ class Vits(BaseTTS):
         if self.args.use_sdp:
             logw = self.duration_predictor(
-                x, x_mask, g=g if self.args.condition_dp_on_speaker else None, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
+                x,
+                x_mask,
+                g=g if self.args.condition_dp_on_speaker else None,
+                reverse=True,
+                noise_scale=self.inference_noise_scale_dp,
+                lang_emb=lang_emb,
             )
         else:
-            logw = self.duration_predictor(x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb)
+            logw = self.duration_predictor(
+                x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
+            )
         w = torch.exp(logw) * x_mask * self.length_scale
         w_ceil = torch.ceil(w)

View File

@@ -113,7 +113,7 @@ def _set_file_path(path):
 def get_language_weighted_sampler(items: list):
-    language_names = np.array([item[3] for item in items])
+    language_names = np.array([item["language"] for item in items])
     unique_language_names = np.unique(language_names).tolist()
     language_ids = [unique_language_names.index(l) for l in language_names]
     language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])

View File

@@ -118,7 +118,7 @@ class SpeakerManager:
         Returns:
             Tuple[Dict, int]: speaker IDs and number of speakers.
         """
-        speakers = sorted({item[2] for item in items})
+        speakers = sorted({item["speaker_name"] for item in items})
         speaker_ids = {name: i for i, name in enumerate(speakers)}
         num_speakers = len(speaker_ids)
         return speaker_ids, num_speakers
@@ -414,7 +414,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
 def get_speaker_weighted_sampler(items: list):
-    speaker_names = np.array([item[2] for item in items])
+    speaker_names = np.array([item["speaker_name"] for item in items])
     unique_speaker_names = np.unique(speaker_names).tolist()
     speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
     speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])

View File

@@ -127,5 +127,7 @@ class ParallelWaveganConfig(BaseGANVocoderConfig):
     lr_scheduler_gen: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
     lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1})
     lr_scheduler_disc: str = "StepLR" # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
-    lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1})
+    lr_scheduler_disc_params: dict = field(
+        default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1}
+    )
     scheduler_after_epoch: bool = False

View File

@@ -5,13 +5,13 @@ from tests import get_tests_input_path
 from TTS.tts.datasets.formatters import common_voice
-class TestPreprocessors(unittest.TestCase):
+class TestTTSFormatters(unittest.TestCase):
     def test_common_voice_preprocessor(self): # pylint: disable=no-self-use
         root_path = get_tests_input_path()
         meta_file = "common_voice.tsv"
         items = common_voice(root_path, meta_file)
-        assert items[0][0] == "The applicants are invited for coffee and visa is given immediately."
-        assert items[0][1] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_20005954.wav")
-        assert items[-1][0] == "Competition for limited resources has also resulted in some local conflicts."
-        assert items[-1][1] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_19737074.wav")
+        assert items[0]["text"] == "The applicants are invited for coffee and visa is given immediately."
+        assert items[0]["audio_file"] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_20005954.wav")
+        assert items[-1]["text"] == "Competition for limited resources has also resulted in some local conflicts."
+        assert items[-1]["audio_file"] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_19737074.wav")