mirror of https://github.com/coqui-ai/TTS.git
make style
parent e22f7a2aca
commit 1472b6df49
@@ -1,14 +1,15 @@
 """Find all the unique characters in a dataset"""
 import argparse
+import multiprocessing
 from argparse import RawTextHelpFormatter

+import numpy
+from tqdm.contrib.concurrent import process_map
+
 from TTS.config import load_config
 from TTS.tts.datasets import load_meta_data
-
-import numpy
-import multiprocessing
 from TTS.tts.utils.text import text2phone
-from tqdm.contrib.concurrent import process_map

+
 def compute_phonemes(item):
     try:
@@ -19,6 +20,7 @@ def compute_phonemes(item):
         return []
     return list(set(ph))

+
 def main():
     global c
     # pylint: disable=bad-option-value
@@ -51,8 +53,6 @@ def main():
     phones_force_lower = [c.lower() for c in phones]
     phones_force_lower = set(phones_force_lower)

-
-
     print(f" > Number of unique phonemes: {len(phones)}")
     print(f" > Unique phonemes: {''.join(sorted(phones))}")
     print(f" > Unique lower phonemes: {''.join(sorted(lower_phones))}")
@@ -1,26 +1,27 @@
 # This code is adpated from: https://github.com/wiseman/py-webrtcvad/blob/master/example.py
-import os
-import tqdm
-import glob
 import argparse
-import pathlib
-
 import collections
 import contextlib
+import glob
+import multiprocessing
+import os
+import pathlib
 import sys
 import wave
+from itertools import chain

 import numpy as np
+import tqdm
 import webrtcvad
 from tqdm.contrib.concurrent import process_map
-import multiprocessing
-from itertools import chain


 def read_wave(path):
     """Reads a .wav file.

     Takes the path, and returns (PCM audio data, sample rate).
     """
-    with contextlib.closing(wave.open(path, 'rb')) as wf:
+    with contextlib.closing(wave.open(path, "rb")) as wf:
         num_channels = wf.getnchannels()
         assert num_channels == 1
         sample_width = wf.getsampwidth()
@@ -36,7 +37,7 @@ def write_wave(path, audio, sample_rate):

     Takes path, PCM audio data, and sample rate.
     """
-    with contextlib.closing(wave.open(path, 'wb')) as wf:
+    with contextlib.closing(wave.open(path, "wb")) as wf:
         wf.setnchannels(1)
         wf.setsampwidth(2)
         wf.setframerate(sample_rate)
@@ -45,6 +46,7 @@ def write_wave(path, audio, sample_rate):

+
 class Frame(object):
     """Represents a "frame" of audio data."""

     def __init__(self, bytes, timestamp, duration):
         self.bytes = bytes
         self.timestamp = timestamp
@@ -64,13 +66,12 @@ def frame_generator(frame_duration_ms, audio, sample_rate):
     timestamp = 0.0
     duration = (float(n) / sample_rate) / 2.0
     while offset + n < len(audio):
-        yield Frame(audio[offset:offset + n], timestamp, duration)
+        yield Frame(audio[offset : offset + n], timestamp, duration)
         timestamp += duration
         offset += n


-def vad_collector(sample_rate, frame_duration_ms,
-                  padding_duration_ms, vad, frames):
+def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, frames):
     """Filters out non-voiced audio frames.

     Given a webrtcvad.Vad and a source of audio frames, yields only
@@ -133,25 +134,26 @@ def vad_collector(sample_rate, frame_duration_ms,
             # unvoiced, then enter NOTTRIGGERED and yield whatever
             # audio we've collected.
             if num_unvoiced > 0.9 * ring_buffer.maxlen:
-                #sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
+                # sys.stdout.write('-(%s)' % (frame.timestamp + frame.duration))
                 triggered = False
-                yield b''.join([f.bytes for f in voiced_frames])
+                yield b"".join([f.bytes for f in voiced_frames])
                 ring_buffer.clear()
                 voiced_frames = []
     # If we have any leftover voiced audio when we run out of input,
     # yield it.
     if voiced_frames:
-        yield b''.join([f.bytes for f in voiced_frames])
+        yield b"".join([f.bytes for f in voiced_frames])


 def remove_silence(filepath):
     filename = os.path.basename(filepath)
-    output_path = filepath.replace(os.path.join(args.input_dir, ''),os.path.join(args.output_dir, ''))
+    output_path = filepath.replace(os.path.join(args.input_dir, ""), os.path.join(args.output_dir, ""))
     # ignore if the file exists
     if os.path.exists(output_path) and not args.force:
         return False
     # create all directory structure
     pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
     padding_duration_ms = 300  # default 300
     audio, sample_rate = read_wave(filepath)
     vad = webrtcvad.Vad(int(args.aggressiveness))
     frames = frame_generator(30, audio, sample_rate)
@@ -180,6 +182,7 @@ def remove_silence(filepath):
         # if fail to remove silence just write the file
         write_wave(output_path, audio, sample_rate)

+
 def preprocess_audios():
     files = sorted(glob.glob(os.path.join(args.input_dir, args.glob), recursive=True))
     print("> Number of files: ", len(files))
@@ -193,21 +196,31 @@ def preprocess_audios():
     else:
         print("> No files Found !")

+
 if __name__ == "__main__":
     """
     usage
     python remove_silence.py -i=VCTK-Corpus-bk/ -o=../VCTK-Corpus-removed-silence -g=wav48/*/*.wav -a=2
     """
     parser = argparse.ArgumentParser()
-    parser.add_argument('-i', '--input_dir', type=str, default='../VCTK-Corpus',
-                        help='Dataset root dir')
-    parser.add_argument('-o', '--output_dir', type=str, default='../VCTK-Corpus-removed-silence',
-                        help='Output Dataset dir')
-    parser.add_argument('-f', '--force', type=bool, default=True,
-                        help='Force the replace of exists files')
-    parser.add_argument('-g', '--glob', type=str, default='**/*.wav',
-                        help='path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav')
-    parser.add_argument('-a', '--aggressiveness', type=int, default=2,
-                        help='set its aggressiveness mode, which is an integer between 0 and 3. 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive.')
+    parser.add_argument("-i", "--input_dir", type=str, default="../VCTK-Corpus", help="Dataset root dir")
+    parser.add_argument(
+        "-o", "--output_dir", type=str, default="../VCTK-Corpus-removed-silence", help="Output Dataset dir"
+    )
+    parser.add_argument("-f", "--force", type=bool, default=True, help="Force the replace of exists files")
+    parser.add_argument(
+        "-g",
+        "--glob",
+        type=str,
+        default="**/*.wav",
+        help="path in glob format for acess wavs from input_dir. ex: wav48/*/*.wav",
+    )
+    parser.add_argument(
+        "-a",
+        "--aggressiveness",
+        type=int,
+        default=2,
+        help="set its aggressiveness mode, which is an integer between 0 and 3. 0 is the least aggressive about filtering out non-speech, 3 is the most aggressive.",
+    )
     args = parser.parse_args()
     preprocess_audios()
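For context, the script above composes its helpers the same way as the upstream py-webrtcvad example it credits: read a mono 16-bit wav, frame it, run the collector, and write the voiced chunks. A minimal sketch of that flow (the input file name and chunk naming are placeholders, not part of this commit):

audio, sample_rate = read_wave("example.wav")  # mono, 16-bit PCM expected
vad = webrtcvad.Vad(2)  # aggressiveness 0-3, as exposed by the -a/--aggressiveness flag
frames = frame_generator(30, audio, sample_rate)  # 30 ms frames, as in the script
segments = vad_collector(sample_rate, 30, 300, vad, list(frames))  # 300 ms padding, as in the script
for i, segment in enumerate(segments):
    write_wave(f"chunk-{i}.wav", segment, sample_rate)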
@@ -5,20 +5,20 @@ import torch.nn as nn

 from TTS.utils.io import load_fsspec


 class PreEmphasis(torch.nn.Module):
     def __init__(self, coefficient=0.97):
         super().__init__()
         self.coefficient = coefficient
-        self.register_buffer(
-            'filter', torch.FloatTensor([-self.coefficient, 1.]).unsqueeze(0).unsqueeze(0)
-        )
+        self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))

     def forward(self, x):
         assert len(x.size()) == 2

-        x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), 'reflect')
+        x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
         return torch.nn.functional.conv1d(x, self.filter).squeeze(1)


 class SELayer(nn.Module):
     def __init__(self, channel, reduction=8):
         super(SELayer, self).__init__()
@@ -110,8 +110,15 @@ class ResNetSpeakerEncoder(nn.Module):
         if self.use_torch_spec:
             self.torch_spec = torch.nn.Sequential(
                 PreEmphasis(audio_config["preemphasis"]),
-                torchaudio.transforms.MelSpectrogram(sample_rate=audio_config["sample_rate"], n_fft=audio_config["fft_size"], win_length=audio_config["win_length"], hop_length=audio_config["hop_length"], window_fn=torch.hamming_window, n_mels=audio_config["num_mels"])
-            )
+                torchaudio.transforms.MelSpectrogram(
+                    sample_rate=audio_config["sample_rate"],
+                    n_fft=audio_config["fft_size"],
+                    win_length=audio_config["win_length"],
+                    hop_length=audio_config["hop_length"],
+                    window_fn=torch.hamming_window,
+                    n_mels=audio_config["num_mels"],
+                ),
+            )
         else:
             self.torch_spec = None

@@ -213,7 +220,7 @@ class ResNetSpeakerEncoder(nn.Module):
         """
         # map to the waveform size
         if self.use_torch_spec:
-            num_frames = num_frames * self.audio_config['hop_length']
+            num_frames = num_frames * self.audio_config["hop_length"]

         max_len = x.shape[1]

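For context on the PreEmphasis module touched above: it registers a two-tap filter [-coefficient, 1.0] and applies it with conv1d after a one-sample reflect pad, which computes y[t] = x[t] - coefficient * x[t-1]. A standalone sketch of the same operation (the batch size and sample count are illustrative assumptions):

import torch

coefficient = 0.97
kernel = torch.FloatTensor([-coefficient, 1.0]).unsqueeze(0).unsqueeze(0)  # shape [1, 1, 2]
x = torch.randn(4, 16000)  # [batch, samples]
padded = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")  # left-pad by one sample
y = torch.nn.functional.conv1d(padded, kernel).squeeze(1)  # y[t] = x[t] - 0.97 * x[t-1]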
@@ -179,10 +179,12 @@ def setup_model(c):
             c.model_params["num_lstm_layers"],
         )
     elif c.model_params["model_name"].lower() == "resnet":
-        model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"],
+        model = ResNetSpeakerEncoder(
+            input_dim=c.model_params["input_dim"],
+            proj_dim=c.model_params["proj_dim"],
             log_input=c.model_params.get("log_input", False),
             use_torch_spec=c.model_params.get("use_torch_spec", False),
-            audio_config=c.audio
+            audio_config=c.audio,
         )
     return model

@@ -265,7 +265,9 @@ class Trainer:
             config = self.config.model_args if hasattr(self.config, "model_args") else self.config
             # save speakers json
             if config.use_language_embedding and self.model.language_manager.num_languages > 1:
-                self.model.language_manager.save_language_ids_to_file(os.path.join(self.output_path, "language_ids.json"))
+                self.model.language_manager.save_language_ids_to_file(
+                    os.path.join(self.output_path, "language_ids.json")
+                )
                 if hasattr(self.config, "model_args"):
                     self.config.model_args["num_languages"] = self.model.language_manager.num_languages
                 else:
@@ -542,6 +542,7 @@ class TTSDataset(Dataset):
             )
         )

+
 class PitchExtractor:
     """Pitch Extractor for computing F0 from wav files.
     Args:
@@ -304,7 +304,7 @@ def vctk(root_path, meta_files=None, wavs_path="wav48", ununsed_speakers=None):
     return items


-def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ununsed_speakers=None): # pylint: disable=unused-argument
+def vctk_slim(root_path, meta_files=None, wavs_path="wav48", ununsed_speakers=None):  # pylint: disable=unused-argument
     """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
     items = []
     txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
@@ -602,7 +602,7 @@ class VitsGeneratorLoss(nn.Module):
         fine_tuning_mode=0,
         use_speaker_encoder_as_loss=False,
         gt_spk_emb=None,
-        syn_spk_emb=None
+        syn_spk_emb=None,
     ):
         """
         Shapes:
@@ -638,7 +638,9 @@ class VitsGeneratorLoss(nn.Module):
         loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration

         if use_speaker_encoder_as_loss:
-            loss_se = - torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.spk_encoder_loss_alpha
+            loss_se = (
+                -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * self.spk_encoder_loss_alpha
+            )
             loss += loss_se
             return_dict["loss_spk_encoder"] = loss_se

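In plain terms, the speaker-consistency term reformatted above is loss_se = -alpha * mean(cosine_similarity(gt_spk_emb, syn_spk_emb)): the generator is pushed to make the synthesized audio's speaker embedding match the ground-truth one. A tiny sketch with placeholder tensors (the embedding size and the alpha value are assumptions, not taken from this commit):

import torch

gt_spk_emb = torch.randn(8, 512)   # speaker embeddings of ground-truth audio
syn_spk_emb = torch.randn(8, 512)  # speaker embeddings of synthesized audio
alpha = 9.0                        # stands in for self.spk_encoder_loss_alpha
loss_se = -torch.nn.functional.cosine_similarity(gt_spk_emb, syn_spk_emb).mean() * alpha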
@@ -178,7 +178,14 @@ class StochasticDurationPredictor(nn.Module):
     """

     def __init__(
-        self, in_channels: int, hidden_channels: int, kernel_size: int, dropout_p: float, num_flows=4, cond_channels=0, language_emb_dim=None
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        kernel_size: int,
+        dropout_p: float,
+        num_flows=4,
+        cond_channels=0,
+        language_emb_dim=None,
     ):
         super().__init__()

@@ -246,7 +246,9 @@ class BaseTTS(BaseModel):
         # setup multi-speaker attributes
         if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
             if hasattr(config, "model_args"):
-                speaker_id_mapping = self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
+                speaker_id_mapping = (
+                    self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
+                )
                 d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
                 config.use_d_vector_file = config.model_args.use_d_vector_file
             else:
@@ -262,7 +264,9 @@ class BaseTTS(BaseModel):
             custom_symbols = self.make_symbols(self.config)

         if hasattr(self, "language_manager"):
-            language_id_mapping = self.language_manager.language_id_mapping if self.args.use_language_embedding else None
+            language_id_mapping = (
+                self.language_manager.language_id_mapping if self.args.use_language_embedding else None
+            )
         else:
             language_id_mapping = None

@@ -229,7 +229,6 @@ class VitsArgs(Coqpit):
     freeze_waveform_decoder: bool = False


-
 class Vits(BaseTTS):
     """VITS TTS model

@@ -306,7 +305,7 @@ class Vits(BaseTTS):
             args.num_layers_text_encoder,
             args.kernel_size_text_encoder,
             args.dropout_p_text_encoder,
-            language_emb_dim=self.embedded_language_dim
+            language_emb_dim=self.embedded_language_dim,
         )

         self.posterior_encoder = PosteriorEncoder(
@@ -389,16 +388,26 @@ class Vits(BaseTTS):
         # TODO: make this a function
         if config.use_speaker_encoder_as_loss:
             if not config.speaker_encoder_model_path or not config.speaker_encoder_config_path:
-                raise RuntimeError(" [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!")
-            self.speaker_manager.init_speaker_encoder(config.speaker_encoder_model_path, config.speaker_encoder_config_path)
+                raise RuntimeError(
+                    " [!] To use the speaker encoder loss you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
+                )
+            self.speaker_manager.init_speaker_encoder(
+                config.speaker_encoder_model_path, config.speaker_encoder_config_path
+            )
             self.speaker_encoder = self.speaker_manager.speaker_encoder.train()
             for param in self.speaker_encoder.parameters():
                 param.requires_grad = False

             print(" > External Speaker Encoder Loaded !!")

-            if hasattr(self.speaker_encoder, "audio_config") and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]:
-                self.audio_transform = torchaudio.transforms.Resample(orig_freq=self.audio_config["sample_rate"], new_freq=self.speaker_encoder.audio_config["sample_rate"])
+            if (
+                hasattr(self.speaker_encoder, "audio_config")
+                and self.audio_config["sample_rate"] != self.speaker_encoder.audio_config["sample_rate"]
+            ):
+                self.audio_transform = torchaudio.transforms.Resample(
+                    orig_freq=self.audio_config["sample_rate"],
+                    new_freq=self.speaker_encoder.audio_config["sample_rate"],
+                )
             else:
                 self.audio_transform = None
         else:
@@ -529,7 +538,13 @@ class Vits(BaseTTS):
         if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
             language_id = self.language_manager.language_id_mapping[language_name]

-        return {"text": text, "speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector, "language_id": language_id}
+        return {
+            "text": text,
+            "speaker_id": speaker_id,
+            "style_wav": style_wav,
+            "d_vector": d_vector,
+            "language_id": language_id,
+        }

     def forward(
         self,
@@ -567,7 +582,7 @@ class Vits(BaseTTS):
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]

         # language embedding
-        lang_emb=None
+        lang_emb = None
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)

@@ -621,9 +636,9 @@ class Vits(BaseTTS):
             o = self.waveform_decoder(z_slice, g=g)

             wav_seg = segment(
                 waveform.transpose(1, 2),
                 slice_ids * self.config.audio.hop_length,
                 self.args.spec_segment_size * self.config.audio.hop_length,
             )

             if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
@@ -653,7 +668,7 @@ class Vits(BaseTTS):
                 "logs_q": logs_q,
                 "waveform_seg": wav_seg,
                 "gt_spk_emb": gt_spk_emb,
-                "syn_spk_emb": syn_spk_emb
+                "syn_spk_emb": syn_spk_emb,
             }
         )
         return outputs
@@ -695,7 +710,7 @@ class Vits(BaseTTS):
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]

         # language embedding
-        lang_emb=None
+        lang_emb = None
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)

@@ -737,9 +752,9 @@ class Vits(BaseTTS):
             o = self.waveform_decoder(z_slice, g=g)

             wav_seg = segment(
                 waveform.transpose(1, 2),
                 slice_ids * self.config.audio.hop_length,
                 self.args.spec_segment_size * self.config.audio.hop_length,
             )

             if self.args.use_speaker_encoder_as_loss and self.speaker_encoder is not None:
@@ -770,7 +785,7 @@ class Vits(BaseTTS):
                 "logs_q": logs_q,
                 "waveform_seg": wav_seg,
                 "gt_spk_emb": gt_spk_emb,
-                "syn_spk_emb": syn_spk_emb
+                "syn_spk_emb": syn_spk_emb,
             }
         )
         return outputs
@@ -790,14 +805,16 @@ class Vits(BaseTTS):
             g = self.emb_g(sid).unsqueeze(-1)

         # language embedding
-        lang_emb=None
+        lang_emb = None
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)

         x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)

         if self.args.use_sdp:
-            logw = self.duration_predictor(x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb)
+            logw = self.duration_predictor(
+                x, x_mask, g=g, reverse=True, noise_scale=self.inference_noise_scale_dp, lang_emb=lang_emb
+            )
         else:
             logw = self.duration_predictor(x, x_mask, g=g, lang_emb=lang_emb)

@@ -866,7 +883,7 @@ class Vits(BaseTTS):
             for param in self.text_encoder.parameters():
                 param.requires_grad = False

-            if hasattr(self, 'emb_l'):
+            if hasattr(self, "emb_l"):
                 for param in self.emb_l.parameters():
                     param.requires_grad = False

@@ -932,7 +949,7 @@ class Vits(BaseTTS):
         with autocast(enabled=False):  # use float32 for the criterion
             loss_dict = criterion[optimizer_idx](
                 waveform_hat=outputs["model_outputs"].float(),
-                waveform= outputs["waveform_seg"].float(),
+                waveform=outputs["waveform_seg"].float(),
                 z_p=outputs["z_p"].float(),
                 logs_q=outputs["logs_q"].float(),
                 m_p=outputs["m_p"].float(),
@@ -945,7 +962,7 @@ class Vits(BaseTTS):
                 fine_tuning_mode=self.args.fine_tuning_mode,
                 use_speaker_encoder_as_loss=self.args.use_speaker_encoder_as_loss,
                 gt_spk_emb=outputs["gt_spk_emb"],
-                syn_spk_emb=outputs["syn_spk_emb"]
+                syn_spk_emb=outputs["syn_spk_emb"],
             )
             # ignore duration loss if fine tuning mode is on
             if not self.args.fine_tuning_mode:
@@ -1,13 +1,14 @@
-import os
 import json
-import torch
+import os
+from typing import Dict, List, Tuple

 import fsspec
 import numpy as np
-from typing import Dict, Tuple, List
+import torch
 from coqpit import Coqpit

 from torch.utils.data.sampler import WeightedRandomSampler


 class LanguageManager:
     """Manage the languages for multi-lingual 🐸TTS models. Load a datafile and parse the information
     in a way that can be queried by language.
@@ -20,7 +21,9 @@ class LanguageManager:
     >>> manager = LanguageManager(language_id_file_path=language_id_file_path)
     >>> language_id_mapper = manager.language_ids
     """
+
     language_id_mapping: Dict = {}
+
     def __init__(
         self,
         language_id_file_path: str = "",
@@ -85,6 +88,7 @@ class LanguageManager:
         """
         self._save_json(file_path, self.language_id_mapping)

+
 def _set_file_path(path):
     """Find the language_ids.json under the given path or the above it.
     Intended to band aid the different paths returned in restored and continued training."""
@@ -97,6 +101,7 @@ def _set_file_path(path):
         return path_continue
     return None

+
 def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None) -> LanguageManager:
     """Initiate a `LanguageManager` instance by the provided config.

@@ -118,7 +123,7 @@ def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None)
         # restoring language manager from a previous run.
         if language_file:
             language_manager.set_language_ids_from_file(language_file)
     if language_manager.num_languages > 0:
         print(
             " > Language manager is loaded with {} languages: {}".format(
                 language_manager.num_languages, ", ".join(language_manager.language_names)
@@ -126,11 +131,12 @@ def get_language_manager(c: Coqpit, data: List = None, restore_path: str = None)
         )
     return language_manager

+
 def get_language_weighted_sampler(items: list):
     language_names = np.array([item[3] for item in items])
     unique_language_names = np.unique(language_names).tolist()
     language_ids = [unique_language_names.index(l) for l in language_names]
     language_count = np.array([len(np.where(language_names == l)[0]) for l in unique_language_names])
-    weight_language = 1. / language_count
+    weight_language = 1.0 / language_count
     dataset_samples_weight = torch.from_numpy(np.array([weight_language[l] for l in language_ids])).double()
     return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
@@ -432,11 +432,12 @@ def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None,
             speaker_manager.save_speaker_ids_to_file(out_file_path)
     return speaker_manager

+
 def get_speaker_weighted_sampler(items: list):
     speaker_names = np.array([item[2] for item in items])
     unique_speaker_names = np.unique(speaker_names).tolist()
     speaker_ids = [unique_speaker_names.index(l) for l in speaker_names]
     speaker_count = np.array([len(np.where(speaker_names == l)[0]) for l in unique_speaker_names])
-    weight_speaker = 1. / speaker_count
+    weight_speaker = 1.0 / speaker_count
     dataset_samples_weight = torch.from_numpy(np.array([weight_speaker[l] for l in speaker_ids])).double()
     return WeightedRandomSampler(dataset_samples_weight, len(dataset_samples_weight))
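For context, the two weighted-sampler helpers above return a torch WeightedRandomSampler that can be handed straight to a DataLoader in place of shuffling; a hypothetical usage (dataset and items are placeholders, not part of this commit):

from torch.utils.data import DataLoader

sampler = get_language_weighted_sampler(items)  # inverse-frequency weight per sample's language
loader = DataLoader(dataset, batch_size=32, sampler=sampler)  # sampler replaces shuffle=True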
@@ -136,8 +136,9 @@ def phoneme_cleaners(text):
     text = collapse_whitespace(text)
     return text

+
 def multilingual_cleaners(text):
-    '''Pipeline for multilingual text'''
+    """Pipeline for multilingual text"""
     text = lowercase(text)
     text = replace_symbols(text, lang=None)
     text = remove_aux_symbols(text)
@@ -3,19 +3,27 @@ import os
 import shutil

 from tests import get_device_id, get_tests_output_path, run_cli
-from TTS.tts.configs.vits_config import VitsConfig
 from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.configs.vits_config import VitsConfig

 config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
 output_path = os.path.join(get_tests_output_path(), "train_outputs")


 dataset_config1 = BaseDatasetConfig(
-    name="ljspeech", meta_file_train="metadata.csv", meta_file_val="metadata.csv", path="tests/data/ljspeech", language="en"
+    name="ljspeech",
+    meta_file_train="metadata.csv",
+    meta_file_val="metadata.csv",
+    path="tests/data/ljspeech",
+    language="en",
 )

 dataset_config2 = BaseDatasetConfig(
-    name="ljspeech", meta_file_train="metadata.csv", meta_file_val="metadata.csv", path="tests/data/ljspeech", language="en2"
+    name="ljspeech",
+    meta_file_train="metadata.csv",
+    meta_file_val="metadata.csv",
+    path="tests/data/ljspeech",
+    language="en2",
 )

 config = VitsConfig(