make style

Author: WeberJulian, 2021-11-02 17:31:14 +01:00 (committed by Eren Gölge)
Parent: e22f7a2aca
Commit: 1472b6df49
15 changed files with 158 additions and 87 deletions
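
Context for the hunks below: `make style` applies the repo's automated formatters rather than hand edits. isort regroups and alphabetizes imports; black normalizes string quotes to double quotes, adds trailing commas, splits long calls one-argument-per-line, rewrites floats like `1.` as `1.0`, and fixes blank lines around defs and classes. A minimal sketch of that pipeline, assuming black and isort are installed and a 120-column limit; the repo's actual Makefile target may differ:

    import black
    import isort

    SRC = "import torch\nimport os\nx = {'a': 1.}\n"

    # isort reorders the imports, then black normalizes quotes and float literals.
    formatted = black.format_str(isort.code(SRC), mode=black.Mode(line_length=120))
    print(formatted)  # imports reordered; x = {"a": 1.0}

Every hunk in this commit is an instance of one of those mechanical rewrites.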

View File

@@ -1,14 +1,15 @@
 """Find all the unique characters in a dataset"""
 import argparse
+import multiprocessing
 from argparse import RawTextHelpFormatter
+import numpy
+from tqdm.contrib.concurrent import process_map
 from TTS.config import load_config
 from TTS.tts.datasets import load_meta_data
-import numpy
-import multiprocessing
 from TTS.tts.utils.text import text2phone
-from tqdm.contrib.concurrent import process_map
 def compute_phonemes(item):
     try:
@@ -19,6 +20,7 @@ def compute_phonemes(item):
         return []
     return list(set(ph))
+
 def main():
     global c
     # pylint: disable=bad-option-value
@@ -51,8 +53,6 @@ def main():
     phones_force_lower = [c.lower() for c in phones]
     phones_force_lower = set(phones_force_lower)
     print(f" > Number of unique phonemes: {len(phones)}")
     print(f" > Unique phonemes: {''.join(sorted(phones))}")
     print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}")

View File

@ -1,26 +1,27 @@
# This code is adpated from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py # This code is adpated from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py
import os
import tqdm
import glob
import argparse import argparse
import pathlib
import collections import collections
import contextlib import contextlib
import glob
import multiprocessing
import os
import pathlib
import sys import sys
import wave import wave
from itertools import chain
import numpy as np import numpy as np
import tqdm
import webrtcvad import webrtcvad
from tqdm.contrib.concurrent import process_map from tqdm.contrib.concurrent import process_map
import multiprocessing
from itertools import chain
def read_wave(path): def read_wave(path):
"""Reads a .wav file. """Reads a .wav file.
Takes the path, and returns (PCM audio data, sample rate). Takes the path, and returns (PCM audio data, sample rate).
""" """
with contextlib.closing(wave.open(path, 'rb')) as wf: with contextlib.closing(wave.open(path, "rb")) as wf:
num_channels = wf.getnchannels() num_channels = wf.getnchannels()
assert num_channels == 1 assert num_channels == 1
sample_width = wf.getsampwidth() sample_width = wf.getsampwidth()
@ -36,7 +37,7 @@ def write_wave(path, audio, sample_rate):
Takes path, PCM audio data, and sample rate. Takes path, PCM audio data, and sample rate.
""" """
with contextlib.closing(wave.open(path, 'wb')) as wf: with contextlib.closing(wave.open(path, "wb")) as wf:
wf.setnchannels(1) wf.setnchannels(1)
wf.setsampwidth(2) wf.setsampwidth(2)
wf.setframerate(sample_rate) wf.setframerate(sample_rate)
@ -45,6 +46,7 @@ def write_wave(path, audio, sample_rate):
class Frame(object): class Frame(object):
"""Represents a "frame" of audio data.""" """Represents a "frame" of audio data."""
def __init__(self, bytes, timestamp, duration): def __init__(self, bytes, timestamp, duration):
self.bytes = bytes self.bytes = bytes
self.timestamp = timestamp self.timestamp = timestamp
@ -64,13 +66,12 @@ def frame_generator(frame_duration_ms, audio, sample_rate):
timestamp = 0.0 timestamp = 0.0
duration = (float(n) / sample_rate) / 2.0 duration = (float(n) / sample_rate) / 2.0
while offset + n < len(audio): while offset + n < len(audio):
yield Frame(audio[offset:offset + n], timestamp, duration) yield Frame(audio[offset : offset + n], timestamp, duration)
timestamp += duration timestamp += duration
offset += n offset += n
def vad_collector(sample_rate, frame_duration_ms, def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
padding_duration_ms, vad, frames):
"""Filters out non-voiced audio frames. """Filters out non-voiced audio frames.
Given a webrtcvad.Vad and a source of audio frames, yields only Given a webrtcvad.Vad and a source of audio frames, yields only
@ -133,25 +134,26 @@ def vad_collector(sample_rate, frame_duration_ms,
# unvoiced, then enter NOTTRIGGERED and yield whatever # unvoiced, then enter NOTTRIGGERED and yield whatever
# audio we've collected. # audio we've collected.
if num_unvoiced > 0.9 * ring_buffer.maxlen: if num_unvoiced > 0.9 * ring_buffer.maxlen:
#sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration)) # sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
triggered = False triggered = False
yield b''.join([f.bytes for f in voiced_frames]) yield b"".join([f.bytes for f in voiced_frames])
ring_buffer.clear() ring_buffer.clear()
voiced_frames = [] voiced_frames = []
# If we have any leftover voiced audio when we run out of input, # If we have any leftover voiced audio when we run out of input,
# yield it. # yield it.
if voiced_frames: if voiced_frames:
yield b''.join([f.bytes for f in voiced_frames]) yield b"".join([f.bytes for f in voiced_frames])
def remove_silence(filepath): def remove_silence(filepath):
filename = os.path.basename(filepath) filename = os.path.basename(filepath)
output_path = filepath.replace(os.path.join(args.input_dir, ''),os.path.join(args.output_dir, '')) output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
# ignore if the file exists # ignore if the file exists
if os.path.exists(output_path) and not args.force: if os.path.exists(output_path) and not args.force:
return False return False
# create all directory structure # create all directory structure
pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True) pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
padding_duration_ms = 300 # default 300 padding_duration_ms = 300 # default 300
audio, sample_rate = read_wave(filepath) audio, sample_rate = read_wave(filepath)
vad = webrtcvad.Vad(int(args.aggressiveness)) vad = webrtcvad.Vad(int(args.aggressiveness))
frames = frame_generator(30, audio, sample_rate) frames = frame_generator(30, audio, sample_rate)
@ -180,6 +182,7 @@ def remove_silence(filepath):
# if fail to remove silence just write the file # if fail to remove silence just write the file
write_wave(output_path, audio, sample_rate) write_wave(output_path, audio, sample_rate)
def preprocess_audios(): def preprocess_audios():
files = sorted(glob.glob(os.path.join(args.input_dir, args.glob), recursive=True)) files = sorted(glob.glob(os.path.join(args.input_dir, args.glob), recursive=True))
print("> Number of files: ", len(files)) print("> Number of files: ", len(files))
@ -193,21 +196,31 @@ def preprocess_audios():
else: else:
print("> No files Found !") print("> No files Found !")
if __name__ == "__main__": if __name__ == "__main__":
""" """
usage usage
python remove_silence.py -i=VCTK-Corpus-bk/ -o=../VCTK-Corpus-removed-silence -g=wav48/*/*.wav -a=2 python remove_silence.py -i=VCTK-Corpus-bk/ -o=../VCTK-Corpus-removed-silence -g=wav48/*/*.wav -a=2
""" """
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_dir', type=str, default='../VCTK-Corpus', parser.add_argument("-i", "--input_dir", type=str, default="../VCTK-Corpus", help="Dataset root dir")
help='Dataset root dir') parser.add_argument(
parser.add_argument('-o', '--output_dir', type=str, default='../VCTK-Corpus-removed-silence', "-o", "--output_dir", type=str, default="../VCTK-Corpus-removed-silence", help="Output Dataset dir"
help='Output Dataset dir') )
parser.add_argument('-f', '--force', type=bool, default=True, parser.add_argument("-f", "--force", type=bool, default=True, help="Force the replace of exists files")
help='Force the replace of exists files') parser.add_argument(
parser.add_argument('-g', '--glob', type=str, default='**/*.wav', "-g",
help='path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav') "--glob",
parser.add_argument('-a', '--aggressiveness', type=int, default=2, type=str,
help='set its aggressiveness mode, which is an integer between 0 and 3. 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive.') default="**/*.wav",
help="path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav",
)
parser.add_argument(
"-a",
"--aggressiveness",
type=int,
default=2,
help="set its aggressiveness mode, which is an integer between 0 and 3. 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive.",
)
args = parser.parse_args() args = parser.parse_args()
preprocess_audios() preprocess_audios()
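
A side note on the API this script drives: webrtcvad classifies fixed-size frames of 16-bit mono PCM as speech or non-speech, and `frame_generator`/`vad_collector` above merely batch those per-frame decisions into voiced chunks. A minimal standalone sketch, assuming 16 kHz mono 16-bit audio and using silence as placeholder input:

    import webrtcvad

    vad = webrtcvad.Vad(2)  # aggressiveness 0-3, same meaning as the -a flag above
    sample_rate = 16000
    frame_bytes = int(sample_rate * 30 / 1000) * 2  # 30 ms of 16-bit mono PCM

    pcm = b"\x00" * (frame_bytes * 10)  # placeholder buffer; the real script uses read_wave()
    for i in range(0, len(pcm), frame_bytes):
        frame = pcm[i : i + frame_bytes]
        print(vad.is_speech(frame, sample_rate))  # False for pure silence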

View File

@@ -5,20 +5,20 @@ import torch.nn as nn
 from TTS.utils.io import load_fsspec
 class PreEmphasis(torch.nn.Module):
     def __init__(self, coefficient=0.97):
         super().__init__()
         self.coefficient = coefficient
-        self.register_buffer(
-            'filter', torch.FloatTensor([-self.coefficient, 1.]).unsqueeze(0).unsqueeze(0)
-        )
+        self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
     def forward(self, x):
         assert len(x.size()) == 2
-        x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), 'reflect')
+        x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
         return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
 class SELayer(nn.Module):
     def __init__(self, channel, reduction=8):
         super(SELayer, self).__init__()
@@ -110,8 +110,15 @@ class ResNetSpeakerEncoder(nn.Module):
         if self.use_torch_spec:
             self.torch_spec = torch.nn.Sequential(
                 PreEmphasis(audio_config["preemphasis"]),
-                torchaudio.transforms.MelSpectrogram(sample_rate=audio_config["sample_rate"], n_fft=audio_config["fft_size"], win_length=audio_config["win_length"], hop_length=audio_config["hop_length"], window_fn=torch.hamming_window, n_mels=audio_config["num_mels"])
-            )
+                torchaudio.transforms.MelSpectrogram(
+                    sample_rate=audio_config["sample_rate"],
+                    n_fft=audio_config["fft_size"],
+                    win_length=audio_config["win_length"],
+                    hop_length=audio_config["hop_length"],
+                    window_fn=torch.hamming_window,
+                    n_mels=audio_config["num_mels"],
+                ),
+            )
         else:
             self.torch_spec = None
@@ -213,7 +220,7 @@ class ResNetSpeakerEncoder(nn.Module):
         """
         # map to the waveform size
         if self.use_torch_spec:
-            num_frames = num_frames * self.audio_config['hop_length']
+            num_frames = num_frames * self.audio_config["hop_length"]
         max_len = x.shape[1]
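
For reference, the PreEmphasis module reformatted above implements the standard high-pass pre-emphasis filter y[t] = x[t] - c * x[t-1] as a 1-D convolution with kernel [-c, 1.0]. A standalone numeric check of that equivalence, using the same reflect padding:

    import torch

    c = 0.97
    wav = torch.randn(2, 16000)  # [batch, samples]

    kernel = torch.tensor([[[-c, 1.0]]])  # [out_ch=1, in_ch=1, width=2], as in register_buffer
    padded = torch.nn.functional.pad(wav.unsqueeze(1), (1, 0), "reflect")
    out = torch.nn.functional.conv1d(padded, kernel).squeeze(1)

    expected = wav[:, 1:] - c * wav[:, :-1]
    print(torch.allclose(out[:, 1:], expected))  # True; out[:, 0] uses the reflected sample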

View File

@@ -179,10 +179,12 @@ def setup_model(c):
             c.model_params["num_lstm_layers"],
         )
     elif c.model_params["model_name"].lower() == "resnet":
-        model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"],
+        model = ResNetSpeakerEncoder(
+            input_dim=c.model_params["input_dim"],
+            proj_dim=c.model_params["proj_dim"],
             log_input=c.model_params.get("log_input", False),
             use_torch_spec=c.model_params.get("use_torch_spec", False),
-            audio_config=c.audio
+            audio_config=c.audio,
         )
     return model

View File

@@ -265,7 +265,9 @@ class Trainer:
             config = self.config.model_args if hasattr(self.config, "model_args") else self.config
             # save speakers json
             if config.use_language_embedding and self.model.language_manager.num_languages > 1:
-                self.model.language_manager.save_language_ids_to_file(os.path.join(self.output_path, "language_ids.json"))
+                self.model.language_manager.save_language_ids_to_file(
+                    os.path.join(self.output_path, "language_ids.json")
+                )
                 if hasattr(self.config, "model_args"):
                     self.config.model_args["num_languages"] = self.model.language_manager.num_languages
                 else:

View File

@@ -542,6 +542,7 @@ class TTSDataset(Dataset):
             )
         )
+
 class PitchExtractor:
     """Pitch Extractor for computing F0 from wav files.
     Args:

View File

@@ -304,7 +304,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48", ununsed_speakers=None):
     return items
-def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ununsed_speakers=None): # pylint: disable=unused-argument
+def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ununsed_speakers=None):  # pylint: disable=unused-argument
     """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
     items = []
     txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)

View File

@@ -602,7 +602,7 @@ class VitsGeneratorLoss(nn.Module):
         fine_tuning_mode=0,
         use_speaker_encoder_as_loss=False,
         gt_spk_emb=None,
-        syn_spk_emb=None
+        syn_spk_emb=None,
     ):
         """
         Shapes:
@@ -638,7 +638,9 @@ class VitsGeneratorLoss(nn.Module):
         loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
         if use_speaker_encoder_as_loss:
-            loss_se = - torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.spk_encoder_loss_alpha
+            loss_se = (
+                -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.spk_encoder_loss_alpha
+            )
             loss += loss_se
             return_dict["loss_spk_encoder"] = loss_se
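
The term reformatted above is a speaker-consistency loss: cosine similarity between the speaker embeddings of ground-truth and generated audio, negated so higher similarity lowers the loss, then scaled by `spk_encoder_loss_alpha`. A toy numeric sketch, with the alpha value assumed for illustration:

    import torch

    alpha = 9.0  # illustrative value; the real coefficient comes from the loss config
    gt_spk_emb = torch.randn(4, 512)  # [batch, embedding_dim]

    # Identical embeddings -> cosine similarity 1 -> loss_se == -alpha, its minimum.
    loss_se = -torch.nn.functional.cosine_similarity(gt_spk_emb, gt_spk_emb).mean() * alpha
    print(loss_se)  # tensor(-9.)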

View File

@@ -178,7 +178,14 @@ class StochasticDurationPredictor(nn.Module):
     """
     def __init__(
-        self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0, language_emb_dim=None
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        kernel_size: int,
+        dropout_p: float,
+        num_flows=4,
+        cond_channels=0,
+        language_emb_dim=None,
     ):
         super().__init__()

View File

@@ -246,7 +246,9 @@ class BaseTTS(BaseModel):
         # setup multi-speaker attributes
         if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
             if hasattr(config, "model_args"):
-                speaker_id_mapping = self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
+                speaker_id_mapping = (
+                    self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
+                )
                 d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
                 config.use_d_vector_file = config.model_args.use_d_vector_file
             else:
@@ -262,7 +264,9 @@ class BaseTTS(BaseModel):
             custom_symbols = self.make_symbols(self.config)
         if hasattr(self, "language_manager"):
-            language_id_mapping = self.language_manager.language_id_mapping if self.args.use_language_embedding else None
+            language_id_mapping = (
+                self.language_manager.language_id_mapping if self.args.use_language_embedding else None
+            )
         else:
             language_id_mapping = None

View File

@@ -229,7 +229,6 @@ class VitsArgs(Coqpit):
     freeze_waveform_decoder: bool = False
-
 class Vits(BaseTTS):
     """VITS TTS model
@@ -306,7 +305,7 @@ class Vits(BaseTTS):
             args.num_layers_text_encoder,
             args.kernel_size_text_encoder,
             args.dropout_p_text_encoder,
-            language_emb_dim=self.embedded_language_dim
+            language_emb_dim=self.embedded_language_dim,
         )
         self.posterior_encoder = PosteriorEncoder(
@@ -389,16 +388,26 @@ class Vits(BaseTTS):
         # TODO: make this a function
         if config.use_speaker_encoder_as_loss:
             if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
-                raise RuntimeError(" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!")
-            self.speaker_manager.init_speaker_encoder(config.speaker_encoder_model_path, config.speaker_encoder_config_path)
+                raise RuntimeError(
+                    " [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
+                )
+            self.speaker_manager.init_speaker_encoder(
+                config.speaker_encoder_model_path, config.speaker_encoder_config_path
+            )
             self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
             for param in self.speaker_encoder.parameters():
                 param.requires_grad = False
             print(" > External Speaker Encoder Loaded !!")
-            if hasattr(self.speaker_encoder, "audio_config") and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]:
-                self.audio_transform = torchaudio.transforms.Resample(orig_freq=self.audio_config["sample_rate"], new_freq=self.speaker_encoder.audio_config["sample_rate"])
+            if (
+                hasattr(self.speaker_encoder, "audio_config")
+                and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]
+            ):
+                self.audio_transform = torchaudio.transforms.Resample(
+                    orig_freq=self.audio_config["sample_rate"],
+                    new_freq=self.speaker_encoder.audio_config["sample_rate"],
+                )
             else:
                 self.audio_transform = None
         else:
@@ -529,7 +538,13 @@ class Vits(BaseTTS):
         if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
             language_id = self.language_manager.language_id_mapping[language_name]
-        return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector, "language_id": language_id}
+        return {
+            "text": text,
+            "speaker_id": speaker_id,
+            "style_wav": style_wav,
+            "d_vector": d_vector,
+            "language_id": language_id,
+        }
     def forward(
         self,
@@ -567,7 +582,7 @@ class Vits(BaseTTS):
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
         # language embedding
-        lang_emb=None
+        lang_emb = None
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)
@@ -621,9 +636,9 @@ class Vits(BaseTTS):
         o = self.waveform_decoder(z_slice, g=g)
         wav_seg = segment(
             waveform.transpose(1, 2),
             slice_ids * self.config.audio.hop_length,
             self.args.spec_segment_size * self.config.audio.hop_length,
         )
         if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
@@ -653,7 +668,7 @@ class Vits(BaseTTS):
                 "logs_q": logs_q,
                 "waveform_seg": wav_seg,
                 "gt_spk_emb": gt_spk_emb,
-                "syn_spk_emb": syn_spk_emb
+                "syn_spk_emb": syn_spk_emb,
             }
         )
         return outputs
@@ -695,7 +710,7 @@ class Vits(BaseTTS):
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
         # language embedding
-        lang_emb=None
+        lang_emb = None
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)
@@ -737,9 +752,9 @@ class Vits(BaseTTS):
         o = self.waveform_decoder(z_slice, g=g)
         wav_seg = segment(
             waveform.transpose(1, 2),
             slice_ids * self.config.audio.hop_length,
             self.args.spec_segment_size * self.config.audio.hop_length,
         )
         if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
@@ -770,7 +785,7 @@ class Vits(BaseTTS):
                 "logs_q": logs_q,
                 "waveform_seg": wav_seg,
                 "gt_spk_emb": gt_spk_emb,
-                "syn_spk_emb": syn_spk_emb
+                "syn_spk_emb": syn_spk_emb,
             }
         )
         return outputs
@@ -790,14 +805,16 @@ class Vits(BaseTTS):
             g = self.emb_g(sid).unsqueeze(-1)
         # language embedding
-        lang_emb=None
+        lang_emb = None
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)
         x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
         if self.args.use_sdp:
-            logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb)
+            logw = self.duration_predictor(
+                x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
+            )
         else:
             logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb)
@@ -866,7 +883,7 @@ class Vits(BaseTTS):
         for param in self.text_encoder.parameters():
             param.requires_grad = False
-        if hasattr(self, 'emb_l'):
+        if hasattr(self, "emb_l"):
             for param in self.emb_l.parameters():
                 param.requires_grad = False
@@ -932,7 +949,7 @@ class Vits(BaseTTS):
         with autocast(enabled=False):  # use float32 for the criterion
             loss_dict = criterion[optimizer_idx](
                 waveform_hat=outputs["model_outputs"].float(),
-                waveform= outputs["waveform_seg"].float(),
+                waveform=outputs["waveform_seg"].float(),
                 z_p=outputs["z_p"].float(),
                 logs_q=outputs["logs_q"].float(),
                 m_p=outputs["m_p"].float(),
@@ -945,7 +962,7 @@ class Vits(BaseTTS):
                 fine_tuning_mode=self.args.fine_tuning_mode,
                 use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
                 gt_spk_emb=outputs["gt_spk_emb"],
-                syn_spk_emb=outputs["syn_spk_emb"]
+                syn_spk_emb=outputs["syn_spk_emb"],
             )
             # ignore duration loss if fine tuning mode is on
             if not self.args.fine_tuning_mode:
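
On the `torchaudio.transforms.Resample` block reformatted above: it bridges a sample-rate mismatch between the TTS model's audio and the speaker encoder's expected input. A minimal sketch, with the two rates assumed for illustration:

    import torch
    import torchaudio

    # Assumed rates: TTS model at 22.05 kHz, speaker encoder trained on 16 kHz audio.
    resample = torchaudio.transforms.Resample(orig_freq=22050, new_freq=16000)
    wav = torch.randn(1, 22050)  # one second of audio at the model rate
    print(resample(wav).shape)   # torch.Size([1, 16000])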

View File

@@ -1,13 +1,14 @@
-import os
 import json
-import torch
+import os
+from typing import Dict, List, Tuple
 import fsspec
 import numpy as np
-from typing import Dict, Tuple, List
+import torch
 from coqpit import Coqpit
 from torch.utils.data.sampler import WeightedRandomSampler
 class LanguageManager:
     """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
     in a way that can be queried by language.
@@ -20,7 +21,9 @@ class LanguageManager:
         >>> manager = LanguageManager(language_id_file_path=language_id_file_path)
         >>> language_id_mapper = manager.language_ids
     """
+
     language_id_mapping: Dict = {}
+
     def __init__(
         self,
         language_id_file_path: str = "",
@@ -85,6 +88,7 @@ class LanguageManager:
         """
         self._save_json(file_path, self.language_id_mapping)
+
 def _set_file_path(path):
     """Find the language_ids.json under the given path or the above it.
     Intended to band aid the different paths returned in restored and continued training."""
@@ -97,6 +101,7 @@ def _set_file_path(path):
         return path_continue
     return None
+
 def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None) -> LanguageManager:
     """Initiate a `LanguageManager` instance by the provided config.
@@ -118,7 +123,7 @@ def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None)
     # restoring language manager from a previous run.
     if language_file:
         language_manager.set_language_ids_from_file(language_file)
     if language_manager.num_languages > 0:
         print(
             " > Language manager is loaded with {} languages: {}".format(
                 language_manager.num_languages, ", ".join(language_manager.language_names)
@@ -126,11 +131,12 @@ def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None)
             )
         )
     return language_manager
+
 def get_language_weighted_sampler(items: list):
     language_names = np.array([item[3] for item in items])
     unique_language_names = np.unique(language_names).tolist()
     language_ids = [unique_language_names.index(l) for l in language_names]
     language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
-    weight_language = 1. / language_count
+    weight_language = 1.0 / language_count
     dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
     return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
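
How the sampler reformatted above balances languages: each item gets weight 1/count(its language), so a rarely represented language is drawn about as often, in expectation, as a common one. A toy run with assumed item tuples (only the language field at index 3 matters):

    import numpy as np
    import torch
    from torch.utils.data.sampler import WeightedRandomSampler

    # Toy items shaped like [text, wav_path, speaker, language]; 3 "en" vs 1 "fr".
    items = [["t1", "a.wav", "s1", "en"], ["t2", "b.wav", "s1", "en"],
             ["t3", "c.wav", "s2", "en"], ["t4", "d.wav", "s3", "fr"]]

    language_names = np.array([item[3] for item in items])
    unique = np.unique(language_names).tolist()
    counts = np.array([np.sum(language_names == l) for l in unique])
    weights = np.array([1.0 / counts[unique.index(l)] for l in language_names])  # en: 1/3, fr: 1

    sampler = WeightedRandomSampler(torch.from_numpy(weights).double(), len(weights))
    print(list(sampler))  # index 3 ("fr") appears about as often as all "en" items combined

`get_speaker_weighted_sampler` in the next file is the same construction keyed on the speaker field.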

View File

@@ -432,11 +432,12 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
     speaker_manager.save_speaker_ids_to_file(out_file_path)
     return speaker_manager
+
 def get_speaker_weighted_sampler(items: list):
     speaker_names = np.array([item[2] for item in items])
     unique_speaker_names = np.unique(speaker_names).tolist()
     speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
     speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
-    weight_speaker = 1. / speaker_count
+    weight_speaker = 1.0 / speaker_count
     dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double()
     return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))

View File

@@ -136,8 +136,9 @@ def phoneme_cleaners(text):
     text = collapse_whitespace(text)
     return text
+
 def multilingual_cleaners(text):
-    '''Pipeline for multilingual text'''
+    """Pipeline for multilingual text"""
    text = lowercase(text)
    text = replace_symbols(text, lang=None)
    text = remove_aux_symbols(text)

View File

@@ -3,19 +3,27 @@ import os
 import shutil
 from tests import get_device_id, get_tests_output_path, run_cli
-from TTS.tts.configs.vits_config import VitsConfig
 from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.configs.vits_config import VitsConfig
 config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
 output_path = os.path.join(get_tests_output_path(), "train_outputs")
 dataset_config1 = BaseDatasetConfig(
-    name="ljspeech", meta_file_train="metadata.csv", meta_file_val="metadata.csv", path="tests/data/ljspeech", language="en"
+    name="ljspeech",
+    meta_file_train="metadata.csv",
+    meta_file_val="metadata.csv",
+    path="tests/data/ljspeech",
+    language="en",
 )
 dataset_config2 = BaseDatasetConfig(
-    name="ljspeech", meta_file_train="metadata.csv", meta_file_val="metadata.csv", path="tests/data/ljspeech", language="en2"
+    name="ljspeech",
+    meta_file_train="metadata.csv",
+    meta_file_val="metadata.csv",
+    path="tests/data/ljspeech",
+    language="en2",
 )
 config = VitsConfig(