Update TTS.tts formatters (#1228)

* Return Dict from tts formatters

* Make style
Eren Gölge 2022-02-11 23:03:43 +01:00 committed by GitHub
parent 5e3f499a69
commit 127118c637
41 changed files with 153 additions and 141 deletions
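
In short, the TTS.tts dataset formatters now return a list of dictionaries instead of positional lists like [text, wav_file, speaker_name]. As a rough sketch (the field values below are made up for illustration), one sample after load_tts_samples looks like:

    sample = {
        "text": "The applicants are invited for coffee.",  # transcript (hypothetical value)
        "audio_file": "/data/mytts/wavs/0001.wav",          # path to the wav file (hypothetical path)
        "speaker_name": "MCV_speaker_00",                    # formatter-specific speaker id (hypothetical)
        "language": "en",                                    # merged in by load_tts_samples
        # "alignment_file": "...",                           # optional; only added when an attention-mask meta file is configured
    }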

View File

@@ -29,7 +29,9 @@ parser.add_argument(
     help="Path to dataset config file.",
 )
 parser.add_argument("output_path", type=str, help="path for output speakers.json and/or speakers.npy.")
-parser.add_argument("--old_file", type=str, help="Previous speakers.json file, only compute for new audios.", default=None)
+parser.add_argument(
+    "--old_file", type=str, help="Previous speakers.json file, only compute for new audios.", default=None
+)
 parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
 parser.add_argument("--eval", type=bool, help="compute eval.", default=True)
@@ -41,7 +43,10 @@ meta_data_train, meta_data_eval = load_tts_samples(c_dataset.datasets, eval_spli
 wav_files = meta_data_train + meta_data_eval
 speaker_manager = SpeakerManager(
-    encoder_model_path=args.model_path, encoder_config_path=args.config_path, d_vectors_file_path=args.old_file, use_cuda=args.use_cuda
+    encoder_model_path=args.model_path,
+    encoder_config_path=args.config_path,
+    d_vectors_file_path=args.old_file,
+    use_cuda=args.use_cuda,
 )
 # compute speaker embeddings

View File

@@ -51,7 +51,7 @@ def main():
     N = 0
     for item in tqdm(dataset_items):
         # compute features
-        wav = ap.load_wav(item if isinstance(item, str) else item[1])
+        wav = ap.load_wav(item if isinstance(item, str) else item["audio_file"])
         linear = ap.spectrogram(wav)
         mel = ap.melspectrogram(wav)
@@ -59,13 +59,13 @@ def main():
         N += mel.shape[1]
         mel_sum += mel.sum(1)
         linear_sum += linear.sum(1)
-        mel_square_sum += (mel ** 2).sum(axis=1)
-        linear_square_sum += (linear ** 2).sum(axis=1)
+        mel_square_sum += (mel**2).sum(axis=1)
+        linear_square_sum += (linear**2).sum(axis=1)
     mel_mean = mel_sum / N
-    mel_scale = np.sqrt(mel_square_sum / N - mel_mean ** 2)
+    mel_scale = np.sqrt(mel_square_sum / N - mel_mean**2)
     linear_mean = linear_sum / N
-    linear_scale = np.sqrt(linear_square_sum / N - linear_mean ** 2)
+    linear_scale = np.sqrt(linear_square_sum / N - linear_mean**2)
     output_file_path = args.out_path
     stats = {}

View File

@@ -24,6 +24,7 @@ def main():
     # load all datasets
     train_items, eval_items = load_tts_samples(c.datasets, eval_split=True)
     items = train_items + eval_items
     texts = "".join(item[0] for item in items)

View File

@@ -43,6 +43,11 @@ def main():
     items = train_items + eval_items
     print("Num items:", len(items))
+    is_lang_def = all(item["language"] for item in items)
+    if not c.phoneme_language or not is_lang_def:
+        raise ValueError("Phoneme language must be defined in config.")
     phonemes = process_map(compute_phonemes, items, max_workers=multiprocessing.cpu_count(), chunksize=15)
     phones = []
     for ph in phonemes:

View File

@@ -1,4 +1,5 @@
 import os
+import torch
 from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config

View File

@@ -78,12 +78,12 @@ class SpeakerEncoderDataset(Dataset):
         mel = self.ap.melspectrogram(wav).astype("float32")
         # sample seq_len
-        assert text.size > 0, self.items[idx][1]
-        assert wav.size > 0, self.items[idx][1]
+        assert text.size > 0, self.items[idx]["audio_file"]
+        assert wav.size > 0, self.items[idx]["audio_file"]
         sample = {
             "mel": mel,
-            "item_idx": self.items[idx][1],
+            "item_idx": self.items[idx]["audio_file"],
             "speaker_name": speaker_name,
         }
         return sample
@@ -91,8 +91,8 @@ class SpeakerEncoderDataset(Dataset):
     def __parse_items(self):
         self.speaker_to_utters = {}
         for i in self.items:
-            path_ = i[1]
-            speaker_ = i[2]
+            path_ = i["audio_file"]
+            speaker_ = i["speaker_name"]
             if speaker_ in self.speaker_to_utters.keys():
                 self.speaker_to_utters[speaker_].append(path_)
             else:

View File

@@ -229,7 +229,7 @@ class ResNetSpeakerEncoder(nn.Module):
             x = torch.sum(x * w, dim=2)
         elif self.encoder_type == "ASP":
             mu = torch.sum(x * w, dim=2)
-            sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5))
+            sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
             x = torch.cat((mu, sg), 1)
         x = x.view(x.size()[0], -1)

View File

@@ -113,7 +113,7 @@ class AugmentWAV(object):
     def additive_noise(self, noise_type, audio):
-        clean_db = 10 * np.log10(np.mean(audio ** 2) + 1e-4)
+        clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4)
         noise_list = random.sample(
             self.noise_list[noise_type],
@@ -135,7 +135,7 @@ class AugmentWAV(object):
                 self.additive_noise_config[noise_type]["min_snr_in_db"],
                 self.additive_noise_config[noise_type]["max_num_noises"],
             )
-            noise_db = 10 * np.log10(np.mean(noiseaudio ** 2) + 1e-4)
+            noise_db = 10 * np.log10(np.mean(noiseaudio**2) + 1e-4)
             noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio
             if noises_wav is None:
@@ -154,7 +154,7 @@ class AugmentWAV(object):
         rir_file = random.choice(self.rir_files)
         rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate)
-        rir = rir / np.sqrt(np.sum(rir ** 2))
+        rir = rir / np.sqrt(np.sum(rir**2))
         return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len]
     def apply_one(self, audio):

View File

@@ -75,14 +75,14 @@ def load_tts_samples(
         formatter = _get_formatter_by_name(name)
         # load train set
         meta_data_train = formatter(root_path, meta_file_train, ignored_speakers=ignored_speakers)
-        meta_data_train = [[*item, language] for item in meta_data_train]
+        meta_data_train = [{**item, **{"language": language}} for item in meta_data_train]
         print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
         # load evaluation split if set
         if eval_split:
             if meta_file_val:
                 meta_data_eval = formatter(root_path, meta_file_val, ignored_speakers=ignored_speakers)
-                meta_data_eval = [[*item, language] for item in meta_data_eval]
+                meta_data_eval = [{**item, **{"language": language}} for item in meta_data_eval]
             else:
                 meta_data_eval, meta_data_train = split_dataset(meta_data_train)
             meta_data_eval_all += meta_data_eval
@@ -91,12 +91,12 @@ def load_tts_samples(
         if dataset.meta_file_attn_mask:
             meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
             for idx, ins in enumerate(meta_data_train_all):
-                attn_file = meta_data[ins[1]].strip()
-                meta_data_train_all[idx].append(attn_file)
+                attn_file = meta_data[ins["audio_file"]].strip()
+                meta_data_train_all[idx].update({"alignment_file": attn_file})
             if meta_data_eval_all:
                 for idx, ins in enumerate(meta_data_eval_all):
-                    attn_file = meta_data[ins[1]].strip()
-                    meta_data_eval_all[idx].append(attn_file)
+                    attn_file = meta_data[ins["audio_file"]].strip()
+                    meta_data_eval_all[idx].update({"alignment_file": attn_file})
         # set none for the next iter
         formatter = None
     return meta_data_train_all, meta_data_eval_all
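
For reference, the merge expression used above attaches the dataset-level language to every formatter item without touching the other keys; a minimal sketch with made-up values:

    language = "en"  # hypothetical dataset-level language
    item = {"text": "hello world", "audio_file": "wavs/0001.wav", "speaker_name": "spk_0"}  # hypothetical item
    merged = {**item, **{"language": language}}
    # merged == {"text": "hello world", "audio_file": "wavs/0001.wav", "speaker_name": "spk_0", "language": "en"}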

View File

@@ -21,7 +21,7 @@ class TTSDataset(Dataset):
         text_cleaner: list,
         compute_linear_spec: bool,
         ap: AudioProcessor,
-        meta_data: List[List],
+        meta_data: List[Dict],
         compute_f0: bool = False,
         f0_cache_path: str = None,
         characters: Dict = None,
@@ -54,7 +54,7 @@ class TTSDataset(Dataset):
             ap (TTS.tts.utils.AudioProcessor): Audio processor object.
-            meta_data (list): List of dataset instances.
+            meta_data (list): List of dataset samples.
             compute_f0 (bool): compute f0 if True. Defaults to False.
@@ -199,15 +199,9 @@ class TTSDataset(Dataset):
     def load_data(self, idx):
         item = self.items[idx]
-        if len(item) == 5:
-            text, wav_file, speaker_name, language_name, attn_file = item
-        else:
-            text, wav_file, speaker_name, language_name = item
-            attn = None
-        raw_text = text
-        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
+        raw_text = item["text"]
+        wav = np.asarray(self.load_wav(item["audio_file"]), dtype=np.float32)
         # apply noise for augmentation
         if self.use_noise_augment:
@@ -216,12 +210,12 @@ class TTSDataset(Dataset):
         if not self.input_seq_computed:
             if self.use_phonemes:
                 text = self._load_or_generate_phoneme_sequence(
-                    wav_file,
-                    text,
+                    item["audio_file"],
+                    item["text"],
                     self.phoneme_cache_path,
                     self.enable_eos_bos,
                     self.cleaners,
-                    language_name if language_name else self.phoneme_language,
+                    item["language"] if item["language"] else self.phoneme_language,
                     self.custom_symbols,
                     self.characters,
                     self.add_blank,
@@ -229,7 +223,7 @@ class TTSDataset(Dataset):
             else:
                 text = np.asarray(
                     text_to_sequence(
-                        text,
+                        item["text"],
                         [self.cleaners],
                         custom_symbols=self.custom_symbols,
                         tp=self.characters,
@@ -238,11 +232,12 @@ class TTSDataset(Dataset):
                     dtype=np.int32,
                 )
-        assert text.size > 0, self.items[idx][1]
-        assert wav.size > 0, self.items[idx][1]
+        assert text.size > 0, self.items[idx]["audio_file"]
+        assert wav.size > 0, self.items[idx]["audio_file"]
-        if "attn_file" in locals():
-            attn = np.load(attn_file)
+        attn = None
+        if "alignment_file" in item:
+            attn = np.load(item["alignment_file"])
         if len(text) > self.max_seq_len:
             # return a different sample if the phonemized
@@ -252,7 +247,7 @@ class TTSDataset(Dataset):
         pitch = None
         if self.compute_f0:
-            pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, wav_file, self.f0_cache_path)
+            pitch = self.pitch_extractor.load_or_compute_pitch(self.ap, item["audio_file"], self.f0_cache_path)
             pitch = self.pitch_extractor.normalize_pitch(pitch.astype(np.float32))
         sample = {
@@ -261,10 +256,10 @@ class TTSDataset(Dataset):
             "wav": wav,
             "pitch": pitch,
             "attn": attn,
-            "item_idx": self.items[idx][1],
-            "speaker_name": speaker_name,
-            "language_name": language_name,
-            "wav_file_name": os.path.basename(wav_file),
+            "item_idx": item["audio_file"],
+            "speaker_name": item["speaker_name"],
+            "language_name": item["language"],
+            "wav_file_name": os.path.basename(item["audio_file"]),
         }
         return sample
@@ -272,11 +267,10 @@ class TTSDataset(Dataset):
     def _phoneme_worker(args):
         item = args[0]
         func_args = args[1]
-        text, wav_file, *_ = item
         func_args[3] = (
-            item[3] if item[3] else func_args[3]
+            item["language"] if "language" in item and item["language"] else func_args[3]
         )  # override phoneme language if specified by the dataset formatter
-        phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
+        phonemes = TTSDataset._load_or_generate_phoneme_sequence(item["audio_file"], item["text"], *func_args)
         return phonemes
     def compute_input_seq(self, num_workers=0):
@@ -286,10 +280,9 @@ class TTSDataset(Dataset):
         if self.verbose:
             print(" | > Computing input sequences ...")
         for idx, item in enumerate(tqdm.tqdm(self.items)):
-            text, *_ = item
             sequence = np.asarray(
                 text_to_sequence(
-                    text,
+                    item["text"],
                     [self.cleaners],
                     custom_symbols=self.custom_symbols,
                     tp=self.characters,
@@ -337,10 +330,10 @@ class TTSDataset(Dataset):
         if by_audio_len:
             lengths = []
             for item in self.items:
-                lengths.append(os.path.getsize(item[1]) / 16 * 8)  # assuming 16bit audio
+                lengths.append(os.path.getsize(item["audio_file"]) / 16 * 8)  # assuming 16bit audio
             lengths = np.array(lengths)
         else:
-            lengths = np.array([len(ins[0]) for ins in self.items])
+            lengths = np.array([len(ins["text"]) for ins in self.items])
         idxs = np.argsort(lengths)
         new_items = []
@@ -555,7 +548,7 @@ class PitchExtractor:
     def __init__(
         self,
-        items: List[List],
+        items: List[Dict],
         verbose=False,
     ):
        self.items = items
@@ -614,10 +607,9 @@ class PitchExtractor:
         item = args[0]
         ap = args[1]
         cache_path = args[2]
-        _, wav_file, *_ = item
-        pitch_file = PitchExtractor.create_pitch_file_path(wav_file, cache_path)
+        pitch_file = PitchExtractor.create_pitch_file_path(item["audio_file"], cache_path)
         if not os.path.exists(pitch_file):
-            pitch = PitchExtractor._compute_and_save_pitch(ap, wav_file, pitch_file)
+            pitch = PitchExtractor._compute_and_save_pitch(ap, item["audio_file"], pitch_file)
             return pitch
         return None
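
The length-based ordering above now reads paths and transcripts by key; a minimal standalone sketch of the same idea (assumes 16-bit wav files; the items themselves are the dicts described earlier):

    import os

    import numpy as np


    def sort_items_by_audio_length(items):
        # approximate the signal length from the file size, assuming 16-bit audio
        lengths = np.array([os.path.getsize(item["audio_file"]) / 16 * 8 for item in items])
        idxs = np.argsort(lengths)  # ascending order
        return [items[i] for i in idxs]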

View File

@@ -24,7 +24,7 @@ def tweb(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("\t")
             wav_file = os.path.join(root_path, cols[0] + ".wav")
             text = cols[1]
-            items.append([text, wav_file, speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items
@@ -39,7 +39,7 @@ def mozilla(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             wav_file = cols[1].strip()
             text = cols[0].strip()
             wav_file = os.path.join(root_path, "wavs", wav_file)
-            items.append([text, wav_file, speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items
@@ -55,7 +55,7 @@ def mozilla_de(root_path, meta_file, **kwargs):  # pylint: disable=unused-argume
             text = cols[1].strip()
             folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
             wav_file = os.path.join(root_path, folder_name, wav_file)
-            items.append([text, wav_file, speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items
@@ -101,7 +101,7 @@ def mailabs(root_path, meta_files=None, ignored_speakers=None):
                 wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
             if os.path.isfile(wav_file):
                 text = cols[1].strip()
-                items.append([text, wav_file, speaker_name])
+                items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
             else:
                 # M-AI-Labs have some missing samples, so just print the warning
                 print("> File %s does not exist!" % (wav_file))
@@ -119,7 +119,7 @@ def ljspeech(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2]
-            items.append([text, wav_file, speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items
@@ -133,7 +133,7 @@ def ljspeech_test(root_path, meta_file, **kwargs):  # pylint: disable=unused-arg
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2]
-            items.append([text, wav_file, f"ljspeech-{idx}"])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": f"ljspeech-{idx}"})
     return items
@@ -150,7 +150,7 @@ def sam_accenture(root_path, meta_file, **kwargs):  # pylint: disable=unused-arg
         if not os.path.exists(wav_file):
             print(f" [!] {wav_file} in metafile does not exist. Skipping...")
             continue
-        items.append([text, wav_file, speaker_name])
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items
@@ -165,7 +165,7 @@ def ruslan(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
             text = cols[1]
-            items.append([text, wav_file, speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items
@@ -179,7 +179,7 @@ def css10(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, cols[0])
             text = cols[1]
-            items.append([text, wav_file, speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items
@@ -193,7 +193,7 @@ def nancy(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             utt_id = line.split()[1]
             text = line[line.find('"') + 1 : line.rfind('"') - 1]
             wav_file = os.path.join(root_path, "wavn", utt_id + ".wav")
-            items.append([text, wav_file, speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items
@@ -213,7 +213,7 @@ def common_voice(root_path, meta_file, ignored_speakers=None):
                 if speaker_name in ignored_speakers:
                     continue
             wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
-            items.append([text, wav_file, "MCV_" + speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": "MCV_" + speaker_name})
     return items
@@ -240,7 +240,7 @@ def libri_tts(root_path, meta_files=None, ignored_speakers=None):
                 if isinstance(ignored_speakers, list):
                     if speaker_name in ignored_speakers:
                         continue
-                items.append([text, wav_file, "LTTS_" + speaker_name])
+                items.append({"text": text, "audio_file": wav_file, "speaker_name": f"LTTS_{speaker_name}"})
     for item in items:
         assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}"
     return items
@@ -259,7 +259,7 @@ def custom_turkish(root_path, meta_file, **kwargs):  # pylint: disable=unused-ar
                 skipped_files.append(wav_file)
                 continue
             text = cols[1].strip()
-            items.append([text, wav_file, speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     print(f" [!] {len(skipped_files)} files skipped. They don't exist...")
     return items
@@ -281,7 +281,7 @@ def brspeech(root_path, meta_file, ignored_speakers=None):
             if isinstance(ignored_speakers, list):
                 if speaker_id in ignored_speakers:
                     continue
-            items.append([text, wav_file, speaker_id])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_id})
     return items
@@ -299,7 +299,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48", ignored_speakers=None):
         with open(meta_file, "r", encoding="utf-8") as file_text:
             text = file_text.readlines()[0]
         wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
-        items.append([text, wav_file, "VCTK_" + speaker_id])
+        items.append({"text": text, "audio_file": wav_file, "speaker_name": "VCTK_" + speaker_id})
     return items
@@ -334,7 +334,7 @@ def mls(root_path, meta_files=None, ignored_speakers=None):
             if isinstance(ignored_speakers, list):
                 if speaker in ignored_speakers:
                     continue
-            items.append([text, wav_file, "MLS_" + speaker])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": "MLS_" + speaker})
     return items
@@ -404,7 +404,7 @@ def baker(root_path: str, meta_file: str, **kwargs) -> List[List[str]]:  # pylin
         for line in ttf:
             wav_name, text = line.rstrip("\n").split("|")
             wav_path = os.path.join(root_path, "clips_22", wav_name)
-            items.append([text, wav_path, speaker_name])
+            items.append({"text": text, "audio_file": wav_path, "speaker_name": speaker_name})
     return items
@@ -418,5 +418,5 @@ def kokoro(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
             cols = line.split("|")
             wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
             text = cols[2].replace(" ", "")
-            items.append([text, wav_file, speaker_name])
+            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
     return items

View File

@@ -113,7 +113,7 @@ class ActNorm(nn.Module):
         denom = torch.sum(x_mask, [0, 2])
         m = torch.sum(x * x_mask, [0, 2]) / denom
         m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
-        v = m_sq - (m ** 2)
+        v = m_sq - (m**2)
         logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
         bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)

View File

@@ -65,7 +65,7 @@ class WN(torch.nn.Module):
             self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
         # intermediate layers
         for i in range(num_layers):
-            dilation = dilation_rate ** i
+            dilation = dilation_rate**i
             padding = int((kernel_size * dilation - dilation) / 2)
             in_layer = torch.nn.Conv1d(
                 hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding

View File

@@ -101,7 +101,7 @@ class Encoder(nn.Module):
         self.encoder_type = encoder_type
         # embedding layer
         self.emb = nn.Embedding(num_chars, hidden_channels)
-        nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
+        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
         # init encoder module
         if encoder_type.lower() == "rel_pos_transformer":
             if use_prenet:

View File

@@ -88,7 +88,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
         # relative positional encoding layers
         if rel_attn_window_size is not None:
             n_heads_rel = 1 if heads_share else num_heads
-            rel_stddev = self.k_channels ** -0.5
+            rel_stddev = self.k_channels**-0.5
             emb_rel_k = nn.Parameter(
                 torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev
             )
@@ -235,7 +235,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
         batch, heads, length, _ = x.size()
         # padd along column
         x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0])
-        x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
+        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
         # add 0's in the beginning that will skew the elements after reshape
         x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0])
         x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]

View File

@@ -218,7 +218,7 @@ class GuidedAttentionLoss(torch.nn.Module):
     def _make_ga_mask(ilen, olen, sigma):
         grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
         grid_x, grid_y = grid_x.float(), grid_y.float()
-        return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2)))
+        return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2)))
     @staticmethod
     def _make_masks(ilens, olens):
@@ -665,7 +665,7 @@ class VitsDiscriminatorLoss(nn.Module):
             dr = dr.float()
             dg = dg.float()
             real_loss = torch.mean((1 - dr) ** 2)
-            fake_loss = torch.mean(dg ** 2)
+            fake_loss = torch.mean(dg**2)
             loss += real_loss + fake_loss
             real_losses.append(real_loss.item())
             fake_losses.append(fake_loss.item())

View File

@@ -141,7 +141,7 @@ class MultiHeadAttention(nn.Module):
         # score = softmax(QK^T / (d_k ** 0.5))
         scores = torch.matmul(queries, keys.transpose(2, 3))  # [h, N, T_q, T_k]
-        scores = scores / (self.key_dim ** 0.5)
+        scores = scores / (self.key_dim**0.5)
         scores = F.softmax(scores, dim=3)
         # out = score * V

View File

@@ -57,7 +57,7 @@ class TextEncoder(nn.Module):
         self.emb = nn.Embedding(n_vocab, hidden_channels)
-        nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
+        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
         if language_emb_dim:
             hidden_channels += language_emb_dim

View File

@@ -33,7 +33,7 @@ class DilatedDepthSeparableConv(nn.Module):
         self.norms_1 = nn.ModuleList()
         self.norms_2 = nn.ModuleList()
         for i in range(num_layers):
-            dilation = kernel_size ** i
+            dilation = kernel_size**i
             padding = (kernel_size * dilation - dilation) // 2
             self.convs_sep.append(
                 nn.Conv1d(channels, channels, kernel_size, groups=channels, dilation=dilation, padding=padding)
@@ -264,7 +264,7 @@ class StochasticDurationPredictor(nn.Module):
             # posterior encoder - neg log likelihood
             logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
             nll_posterior_encoder = (
-                torch.sum(-0.5 * (math.log(2 * math.pi) + (noise ** 2)) * x_mask, [1, 2]) - logdet_tot_q
+                torch.sum(-0.5 * (math.log(2 * math.pi) + (noise**2)) * x_mask, [1, 2]) - logdet_tot_q
             )
             z0 = torch.log(torch.clamp_min(z0, 1e-5)) * x_mask
@@ -279,7 +279,7 @@ class StochasticDurationPredictor(nn.Module):
             z = torch.flip(z, [1])
             # flow layers - neg log likelihood
-            nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
+            nll_flow_layers = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
             return nll_flow_layers + nll_posterior_encoder
         flows = list(reversed(self.flows))

View File

@@ -206,9 +206,9 @@ class GlowTTS(BaseTTS):
         with torch.no_grad():
             o_scale = torch.exp(-2 * o_log_scale)
             logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
-            logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2))  # [b, t, d] x [b, d, t'] = [b, t, t']
+            logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
             logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
-            logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
             logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
             attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
         y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
@@ -255,9 +255,9 @@ class GlowTTS(BaseTTS):
         # find the alignment path between z and encoder output
         o_scale = torch.exp(-2 * o_log_scale)
         logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
-        logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2))  # [b, t, d] x [b, d, t'] = [b, t, t']
+        logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
         logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
-        logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+        logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
         logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
         attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

View File

@@ -4,7 +4,6 @@ from itertools import chain
 from typing import Dict, List, Tuple
 import torch
-import torchaudio
 from coqpit import Coqpit
 from torch import nn
@@ -591,9 +590,9 @@ class Vits(BaseTTS):
             with torch.no_grad():
                 o_scale = torch.exp(-2 * logs_p)
                 logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
-                logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p ** 2)])
+                logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
                 logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
-                logp4 = torch.sum(-0.5 * (m_p ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+                logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
                 logp = logp2 + logp3 + logp1 + logp4
                 attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
@@ -692,10 +691,17 @@ class Vits(BaseTTS):
         if self.args.use_sdp:
             logw = self.duration_predictor(
-                x, x_mask, g=g if self.args.condition_dp_on_speaker else None, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
+                x,
+                x_mask,
+                g=g if self.args.condition_dp_on_speaker else None,
+                reverse=True,
+                noise_scale=self.inference_noise_scale_dp,
+                lang_emb=lang_emb,
             )
         else:
-            logw = self.duration_predictor(x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb)
+            logw = self.duration_predictor(
+                x, x_mask, g=g if self.args.condition_dp_on_speaker else None, lang_emb=lang_emb
+            )
         w = torch.exp(logw) * x_mask * self.length_scale
         w_ceil = torch.ceil(w)

View File

@@ -113,7 +113,7 @@ def _set_file_path(path):
 def get_language_weighted_sampler(items: list):
-    language_names = np.array([item[3] for item in items])
+    language_names = np.array([item["language"] for item in items])
     unique_language_names = np.unique(language_names).tolist()
     language_ids = [unique_language_names.index(l) for l in language_names]
     language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
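
Only the first lines of this sampler are visible in the hunk; the remainder (turning per-language counts into per-sample weights) is not shown, so the following is an assumption-laden sketch of how such a sampler is typically completed:

    import numpy as np
    import torch
    from torch.utils.data import WeightedRandomSampler


    def language_weighted_sampler(items):
        language_names = np.array([item["language"] for item in items])
        unique_language_names = np.unique(language_names).tolist()
        language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
        weight_language = 1.0 / language_count  # rarer languages get proportionally higher weight
        sample_weights = np.array([weight_language[unique_language_names.index(l)] for l in language_names])
        return WeightedRandomSampler(torch.from_numpy(sample_weights).double(), len(sample_weights))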

View File

@@ -118,7 +118,7 @@ class SpeakerManager:
         Returns:
             Tuple[Dict, int]: speaker IDs and number of speakers.
         """
-        speakers = sorted({item[2] for item in items})
+        speakers = sorted({item["speaker_name"] for item in items})
         speaker_ids = {name: i for i, name in enumerate(speakers)}
         num_speakers = len(speaker_ids)
         return speaker_ids, num_speakers
@@ -414,7 +414,7 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
 def get_speaker_weighted_sampler(items: list):
-    speaker_names = np.array([item[2] for item in items])
+    speaker_names = np.array([item["speaker_name"] for item in items])
     unique_speaker_names = np.unique(speaker_names).tolist()
     speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
     speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])

View File

@@ -8,7 +8,7 @@ from torch.autograd import Variable
 def gaussian(window_size, sigma):
-    gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) for x in range(window_size)])
+    gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) for x in range(window_size)])
     return gauss / gauss.sum()
@@ -33,8 +33,8 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True):
     sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
     sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
-    C1 = 0.01 ** 2
-    C2 = 0.03 ** 2
+    C1 = 0.01**2
+    C2 = 0.03**2
     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

View File

@@ -142,10 +142,10 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
         )
         M = o[:, :, :, 0]
         P = o[:, :, :, 1]
-        S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
+        S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
         if self.power is not None:
-            S = S ** self.power
+            S = S**self.power
         if self.use_mel:
             S = torch.matmul(self.mel_basis.to(x), S)
@@ -634,8 +634,8 @@ class AudioProcessor(object):
         S = self._db_to_amp(S)
         # Reconstruct phase
         if self.preemphasis != 0:
-            return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
-        return self._griffin_lim(S ** self.power)
+            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
+        return self._griffin_lim(S**self.power)
     def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
         """Convert a melspectrogram to a waveform using Griffi-Lim vocoder."""
@@ -643,8 +643,8 @@ class AudioProcessor(object):
         S = self._db_to_amp(D)
         S = self._mel_to_linear(S)  # Convert back to linear
         if self.preemphasis != 0:
-            return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
-        return self._griffin_lim(S ** self.power)
+            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
+        return self._griffin_lim(S**self.power)
     def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
         """Convert a full scale linear spectrogram output of a network to a melspectrogram.
@@ -781,7 +781,7 @@ class AudioProcessor(object):
     @staticmethod
     def _rms_norm(wav, db_level=-27):
         r = 10 ** (db_level / 20)
-        a = np.sqrt((len(wav) * (r ** 2)) / np.sum(wav ** 2))
+        a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2))
         return wav * a
     def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray:
@@ -853,7 +853,7 @@ class AudioProcessor(object):
     @staticmethod
     def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
-        mu = 2 ** qc - 1
+        mu = 2**qc - 1
         # wav_abs = np.minimum(np.abs(wav), 1.0)
         signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
         # Quantize signal to the specified number of levels.
@@ -865,13 +865,13 @@ class AudioProcessor(object):
     @staticmethod
     def mulaw_decode(wav, qc):
         """Recovers waveform from quantized values."""
-        mu = 2 ** qc - 1
+        mu = 2**qc - 1
         x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
         return x
     @staticmethod
     def encode_16bits(x):
-        return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16)
+        return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
     @staticmethod
     def quantize(x: np.ndarray, bits: int) -> np.ndarray:
@@ -884,12 +884,12 @@ class AudioProcessor(object):
         Returns:
             np.ndarray: Quantized waveform.
         """
-        return (x + 1.0) * (2 ** bits - 1) / 2
+        return (x + 1.0) * (2**bits - 1) / 2
     @staticmethod
     def dequantize(x, bits):
         """Dequantize a waveform from the given number of bits."""
-        return 2 * x / (2 ** bits - 1) - 1
+        return 2 * x / (2**bits - 1) - 1
 def _log(x, base):

View File

@@ -128,7 +128,7 @@ def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") ->
     while True:
         # Read by chunk to avoid filling memory
-        chunk = file_obj.read(1024 ** 2)
+        chunk = file_obj.read(1024**2)
         if not chunk:
             break
         hash_func.update(chunk)

View File

@@ -39,7 +39,7 @@ class NoamLR(torch.optim.lr_scheduler._LRScheduler):
     def get_lr(self):
         step = max(self.last_epoch, 1)
         return [
-            base_lr * self.warmup_steps ** 0.5 * min(step * self.warmup_steps ** -1.5, step ** -0.5)
+            base_lr * self.warmup_steps**0.5 * min(step * self.warmup_steps**-1.5, step**-0.5)
             for base_lr in self.base_lrs
         ]
@@ -63,7 +63,7 @@ def lr_decay(init_lr, global_step, warmup_steps):
     It is only being used by the Speaker Encoder trainer."""
     warmup_steps = float(warmup_steps)
     step = global_step + 1.0
-    lr = init_lr * warmup_steps ** 0.5 * np.minimum(step * warmup_steps ** -1.5, step ** -0.5)
+    lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5, step**-0.5)
     return lr

View File

@@ -127,5 +127,7 @@ class ParallelWaveganConfig(BaseGANVocoderConfig):
     lr_scheduler_gen: str = "StepLR"  # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
     lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1})
     lr_scheduler_disc: str = "StepLR"  # one of the schedulers from https:#pytorch.org/docs/stable/optim.html
-    lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1})
+    lr_scheduler_disc_params: dict = field(
+        default_factory=lambda: {"gamma": 0.5, "step_size": 200000, "last_epoch": -1}
+    )
     scheduler_after_epoch: bool = False

View File

@@ -111,7 +111,7 @@ class WaveRNNDataset(Dataset):
         elif isinstance(self.mode, int):
             coarse = np.stack(coarse).astype(np.int64)
             coarse = torch.LongTensor(coarse)
-            x_input = 2 * coarse[:, : self.seq_len].float() / (2 ** self.mode - 1.0) - 1.0
+            x_input = 2 * coarse[:, : self.seq_len].float() / (2**self.mode - 1.0) - 1.0
             y_coarse = coarse[:, 1:]
         mels = torch.FloatTensor(mels)
         return x_input, mels, y_coarse

View File

@@ -126,9 +126,9 @@ class LVCBlock(torch.nn.Module):
         )
         for i in range(conv_layers):
-            padding = (3 ** i) * int((conv_kernel_size - 1) / 2)
+            padding = (3**i) * int((conv_kernel_size - 1) / 2)
             conv = torch.nn.Conv1d(
-                in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3 ** i
+                in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3**i
             )
             self.convs.append(conv)

View File

@@ -12,7 +12,7 @@ class ResidualStack(nn.Module):
         self.blocks = nn.ModuleList()
         for idx in range(num_res_blocks):
             layer_kernel_size = kernel_size
-            layer_dilation = layer_kernel_size ** idx
+            layer_dilation = layer_kernel_size**idx
             layer_padding = base_padding * layer_dilation
             self.blocks += [
                 nn.Sequential(

View File

@@ -72,6 +72,6 @@ class ResidualBlock(torch.nn.Module):
         s = self.conv1x1_skip(x)
         # for residual connection
-        x = (self.conv1x1_out(x) + residual) * (0.5 ** 2)
+        x = (self.conv1x1_out(x) + residual) * (0.5**2)
         return x, s

View File

@@ -207,7 +207,7 @@ class HifiganGenerator(torch.nn.Module):
             self.ups.append(
                 weight_norm(
                     ConvTranspose1d(
-                        upsample_initial_channel // (2 ** i),
+                        upsample_initial_channel // (2**i),
                         upsample_initial_channel // (2 ** (i + 1)),
                         k,
                         u,

View File

@@ -36,7 +36,7 @@ class MelganGenerator(nn.Module):
         # upsampling layers and residual stacks
         for idx, upsample_factor in enumerate(upsample_factors):
-            layer_in_channels = base_channels // (2 ** idx)
+            layer_in_channels = base_channels // (2**idx)
             layer_out_channels = base_channels // (2 ** (idx + 1))
             layer_filter_size = upsample_factor * 2
             layer_stride = upsample_factor

View File

@@ -35,7 +35,7 @@ class ParallelWaveganDiscriminator(nn.Module):
             if i == 0:
                 dilation = 1
             else:
-                dilation = i if dilation_factor == 1 else dilation_factor ** i
+                dilation = i if dilation_factor == 1 else dilation_factor**i
                 conv_in_channels = conv_channels
             padding = (kernel_size - 1) // 2 * dilation
             conv_layer = [

View File

@@ -142,7 +142,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
         self.apply(_apply_weight_norm)
     @staticmethod
-    def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2 ** x):
+    def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
         assert layers % stacks == 0
         layers_per_cycle = layers // stacks
         dilations = [dilation(i % layers_per_cycle) for i in range(layers)]

View File

@@ -130,7 +130,7 @@ class UnivnetGenerator(torch.nn.Module):
         self.apply(_apply_weight_norm)
     @staticmethod
-    def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2 ** x):
+    def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
         assert layers % stacks == 0
         layers_per_cycle = layers // stacks
         dilations = [dilation(i % layers_per_cycle) for i in range(layers)]

View File

@@ -153,7 +153,7 @@ class Wavegrad(BaseVocoder):
         noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a)
         noise_scale = noise_scale.unsqueeze(1)
         noise = torch.randn_like(y_0)
-        noisy_audio = noise_scale * y_0 + (1.0 - noise_scale ** 2) ** 0.5 * noise
+        noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2) ** 0.5 * noise
         return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0]
     def compute_noise_level(self, beta):
@@ -161,8 +161,8 @@ class Wavegrad(BaseVocoder):
         self.num_steps = len(beta)
         alpha = 1 - beta
         alpha_hat = np.cumprod(alpha)
-        noise_level = np.concatenate([[1.0], alpha_hat ** 0.5], axis=0)
-        noise_level = alpha_hat ** 0.5
+        noise_level = np.concatenate([[1.0], alpha_hat**0.5], axis=0)
+        noise_level = alpha_hat**0.5
         # pylint: disable=not-callable
         self.beta = torch.tensor(beta.astype(np.float32))
@@ -170,7 +170,7 @@ class Wavegrad(BaseVocoder):
         self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32))
         self.noise_level = torch.tensor(noise_level.astype(np.float32))
-        self.c1 = 1 / self.alpha ** 0.5
+        self.c1 = 1 / self.alpha**0.5
         self.c2 = (1 - self.alpha) / (1 - self.alpha_hat) ** 0.5
         self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:]) ** 0.5

View File

@@ -225,7 +225,7 @@ class Wavernn(BaseVocoder):
         super().__init__(config)
         if isinstance(self.args.mode, int):
-            self.n_classes = 2 ** self.args.mode
+            self.n_classes = 2**self.args.mode
         elif self.args.mode == "mold":
             self.n_classes = 3 * 10
         elif self.args.mode == "gauss":

View File

@@ -5,13 +5,13 @@ from tests import get_tests_input_path
 from TTS.tts.datasets.formatters import common_voice
-class TestPreprocessors(unittest.TestCase):
+class TestTTSFormatters(unittest.TestCase):
     def test_common_voice_preprocessor(self):  # pylint: disable=no-self-use
         root_path = get_tests_input_path()
         meta_file = "common_voice.tsv"
         items = common_voice(root_path, meta_file)
-        assert items[0][0] == "The applicants are invited for coffee and visa is given immediately."
-        assert items[0][1] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_20005954.wav")
-        assert items[-1][0] == "Competition for limited resources has also resulted in some local conflicts."
-        assert items[-1][1] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_19737074.wav")
+        assert items[0]["text"] == "The applicants are invited for coffee and visa is given immediately."
+        assert items[0]["audio_file"] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_20005954.wav")
+        assert items[-1]["text"] == "Competition for limited resources has also resulted in some local conflicts."
+        assert items[-1]["audio_file"] == os.path.join(get_tests_input_path(), "clips", "common_voice_en_19737074.wav")

View File

@@ -46,6 +46,6 @@ def test_wavernn():
     config.model_args.mode = 4
     model = Wavernn(config)
     output = model(dummy_x, dummy_m)
-    assert np.all(output.shape == (2, 1280, 2 ** 4)), output.shape
+    assert np.all(output.shape == (2, 1280, 2**4)), output.shape
     output = model.inference(dummy_y, True, 5500, 550)
     assert np.all(output.shape == (256 * (y_size - 1),))