mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'dev' of https://github.com/mozilla/TTS into dev
commit 0f78f5c277
@@ -1,7 +1,7 @@
 {
     "model": "Tacotron2", // one of the model in models/
-    "run_name": "ljspeech-gravesv2",
-    "run_description": "tacotron2 wuth graves attention",
+    "run_name": "ljspeech-stft_params",
+    "run_description": "tacotron2 cosntant stf parameters",
 
     // AUDIO PARAMETERS
     "audio":{
@@ -50,12 +50,11 @@
     "reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
 
     // TRAINING
-    "batch_size": 2, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
+    "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
     "eval_batch_size":16,
     "r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
     "gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceeed.
     "loss_masking": true, // enable / disable loss masking against the sequence padding.
-    "grad_accum": 2, // if N > 1, enable gradient accumulation for N iterations. It is useful for low memory GPUs.
 
     // VALIDATION
     "run_eval": true,
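For readers unfamiliar with the "gradual_training" field: each entry is a [first_step, r, batch_size] triple, and training switches to the last entry whose first_step has been reached. A minimal lookup sketch, illustrative only and not the repository's trainer code:

import bisect  # not required; a plain loop is used below for clarity

def lookup_schedule(schedule, global_step):
    # Return (r, batch_size) for the last entry whose first_step <= global_step.
    r, batch_size = schedule[0][1], schedule[0][2]
    for first_step, new_r, new_bs in schedule:
        if global_step >= first_step:
            r, batch_size = new_r, new_bs
    return r, batch_size

schedule = [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]]
print(lookup_schedule(schedule, 60000))  # -> (3, 32)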
@@ -110,7 +109,7 @@
     "output_path": "/data4/rw/home/Trainings/",
 
     // PHONEMES
-    "phoneme_cache_path": "mozilla_us_phonemes", // phoneme computation is slow, therefore, it caches results in the given folder.
+    "phoneme_cache_path": "mozilla_us_phonemes_2_1", // phoneme computation is slow, therefore, it caches results in the given folder.
     "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
     "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
 
@@ -96,3 +96,32 @@ class AttentionEntropyLoss(nn.Module):
         entropy = torch.distributions.Categorical(probs=align).entropy()
         loss = (entropy / np.log(align.shape[1])).mean()
         return loss
+
+
+class BCELossMasked(nn.Module):
+
+    def __init__(self, pos_weight):
+        super(BCELossMasked, self).__init__()
+        self.pos_weight = pos_weight
+
+    def forward(self, x, target, length):
+        """
+        Args:
+            x: A Variable containing a FloatTensor of size
+                (batch, max_len) which contains the
+                unnormalized probability for each class.
+            target: A Variable containing a LongTensor of size
+                (batch, max_len) which contains the index of the true
+                class for each corresponding step.
+            length: A Variable containing a LongTensor of size (batch,)
+                which contains the length of each data in a batch.
+        Returns:
+            loss: An average loss value in range [0, 1] masked by the length.
+        """
+        # mask: (batch, max_len, 1)
+        target.requires_grad = False
+        mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
+        loss = functional.binary_cross_entropy_with_logits(
+            x * mask, target * mask, pos_weight=self.pos_weight, reduction='sum')
+        loss = loss / mask.sum()
+        return loss
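The new BCELossMasked computes binary cross-entropy over stop-token logits, multiplying both logits and targets by a length mask and normalizing by the number of valid frames. A self-contained sketch of the same idea; the sequence_mask below is a stand-in, not necessarily identical to the repository's helper:

import torch
from torch.nn import functional

def sequence_mask(sequence_length, max_len=None):
    # True for valid time steps, False for padded ones.
    if max_len is None:
        max_len = sequence_length.max().item()
    steps = torch.arange(max_len, device=sequence_length.device)
    return steps.unsqueeze(0) < sequence_length.unsqueeze(1)

# Toy stop targets: batch of 2, max_len 4; the second item has only 2 valid frames.
logits = torch.randn(2, 4)
target = torch.tensor([[0., 0., 0., 1.], [0., 1., 0., 0.]])
length = torch.tensor([4, 2])
mask = sequence_mask(length, max_len=target.size(1)).float()
loss = functional.binary_cross_entropy_with_logits(
    logits * mask, target * mask, pos_weight=torch.tensor(10.), reduction='sum')
loss = loss / mask.sum()
print(loss)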
@@ -39,7 +39,7 @@ class Tacotron(nn.Module):
         encoder_dim = 512 if num_speakers > 1 else 256
         proj_speaker_dim = 80 if num_speakers > 1 else 0
         # embedding layer
-        self.embedding = nn.Embedding(num_chars, 256)
+        self.embedding = nn.Embedding(num_chars, 256, padding_idx=0)
         self.embedding.weight.data.normal_(0, 0.3)
         # boilerplate model
         self.encoder = Encoder(encoder_dim)
@@ -35,7 +35,7 @@ class Tacotron2(nn.Module):
         encoder_dim = 512 if num_speakers > 1 else 512
         proj_speaker_dim = 80 if num_speakers > 1 else 0
         # embedding layer
-        self.embedding = nn.Embedding(num_chars, 512)
+        self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
         std = sqrt(2.0 / (num_chars + 512))
         val = sqrt(3.0) * std # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
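Both embedding changes pass padding_idx=0, so the vector for the padding symbol (index 0) starts at zero and receives no gradient updates; padded character positions then contribute nothing to the encoder input. A quick illustration of the default PyTorch behavior, separate from the models' own weight re-initialization:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4, padding_idx=0)
print(emb.weight.data[0])             # zero vector reserved for the pad index
out = emb(torch.tensor([[0, 3, 0]]))  # padded positions map to that zero vector
print(out[0, 0].abs().sum())          # tensor(0.)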
@@ -85,7 +85,10 @@
 " if use_cuda:\n",
 " waveform = waveform.cpu()\n",
 " waveform = waveform.numpy()\n",
-" print(\" > Run-time: {}\".format(time.time() - t_1))\n",
+" rtf = (time.time() - t_1) / (len(waveform) / ap.sample_rate)\n",
+" print(waveform.shape)\n",
+" print(\" > Run-time: {}\".format(time.time() - t_1))\n",
+" print(\" > Real-time factor: {}\".format(rtf))\n",
 " if figures: \n",
 " visualize(alignment, mel_postnet_spec, stop_tokens, text, ap.hop_length, CONFIG, ap._denormalize(mel_spec)) \n",
 " IPython.display.display(Audio(waveform, rate=CONFIG.audio['sample_rate'], normalize=False)) \n",
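The notebook cell now reports a real-time factor alongside the run time: RTF = synthesis_time / audio_duration, where audio_duration = len(waveform) / ap.sample_rate, so values below 1.0 mean faster-than-real-time synthesis. With assumed numbers:

import time

t_1 = time.time()
# ... run TTS inference here ...
run_time = time.time() - t_1
num_samples, sample_rate = 110250, 22050      # assumed: 5 seconds of audio at 22.05 kHz
rtf = run_time / (num_samples / sample_rate)  # e.g. 2.5 s of compute for 5 s of audio -> 0.5
print(" > Real-time factor: {}".format(rtf))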
@@ -105,8 +105,8 @@ class Synthesizer(object):
             sample_rate=self.ap.sample_rate,
         ).cuda()
 
-        check = torch.load(model_file)
-        self.wavernn.load_state_dict(check['model'], map_location="cpu")
+        check = torch.load(model_file, map_location="cpu")
+        self.wavernn.load_state_dict(check['model'])
         if use_cuda:
             self.wavernn.cuda()
         self.wavernn.eval()
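The fix moves device remapping to torch.load, where PyTorch supports it; nn.Module.load_state_dict takes no map_location argument, so the old call would raise a TypeError, and loading with map_location="cpu" also lets a GPU-trained checkpoint be opened on a CPU-only machine before an optional .cuda(). A minimal sketch with a stand-in module and a toy checkpoint:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                                 # stand-in for the WaveRNN model
torch.save({"model": model.state_dict()}, "check.pth")  # toy checkpoint for the example

check = torch.load("check.pth", map_location="cpu")     # remap tensors onto CPU at load time
model.load_state_dict(check["model"])
if torch.cuda.is_available():
    model.cuda()
model.eval()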
train.py
@@ -13,7 +13,7 @@ from torch.utils.data import DataLoader
 from TTS.datasets.TTSDataset import MyDataset
 from distribute import (DistributedSampler, apply_gradient_allreduce,
                         init_distributed, reduce_tensor)
-from TTS.layers.losses import L1LossMasked, MSELossMasked
+from TTS.layers.losses import L1LossMasked, MSELossMasked, BCELossMasked
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import (
     NoamLR, check_update, count_parameters, create_experiment_folder,
@@ -168,7 +168,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
 
         # loss computation
         stop_loss = criterion_st(stop_tokens,
-                                 stop_targets) if c.stopnet else torch.zeros(1)
+                                 stop_targets, mel_lengths) if c.stopnet else torch.zeros(1)
         if c.loss_masking:
             decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
             if c.model in ["Tacotron", "TacotronGST"]:
@@ -366,7 +366,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
 
             # loss computation
             stop_loss = criterion_st(
-                stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
+                stop_tokens, stop_targets, mel_lengths) if c.stopnet else torch.zeros(1)
             if c.loss_masking:
                 decoder_loss = criterion(decoder_output, mel_input,
                                          mel_lengths)
@@ -494,7 +494,12 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
                     use_cuda,
                     ap,
                     speaker_id=speaker_id,
-                    style_wav=style_wav)
+                    style_wav=style_wav,
+                    truncated=False,
+                    enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
+                    use_griffin_lim=True,
+                    do_trim_silence=False)
 
                 file_path = os.path.join(AUDIO_PATH, str(global_step))
                 os.makedirs(file_path, exist_ok=True)
                 file_path = os.path.join(file_path,
@@ -570,7 +575,7 @@ def main(args): # pylint: disable=redefined-outer-name
     else:
         criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
                                                ] else nn.MSELoss()
-        criterion_st = nn.BCEWithLogitsLoss(
+        criterion_st = BCELossMasked(
             pos_weight=torch.tensor(10)) if c.stopnet else None
 
     if args.restore_path:
@@ -14,7 +14,7 @@ def prepare_data(inputs):
 
 
 def _pad_tensor(x, length):
-    _pad = 0
+    _pad = 0.
    assert x.ndim == 2
    x = np.pad(
        x, [[0, 0], [0, length - x.shape[1]]],
@@ -31,7 +31,7 @@ def prepare_tensor(inputs, out_steps):
 
 
 def _pad_stop_target(x, length):
-    _pad = 1.
+    _pad = 0.
     assert x.ndim == 1
     return np.pad(
         x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
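Padding stop targets with 0. instead of 1. goes together with the new length-masked stop loss: padded frames no longer carry an artificial "stop" label. The helper's behavior, adapted from the diff into a standalone snippet:

import numpy as np

def _pad_stop_target(x, length, _pad=0.):
    # Pad a 1D stop-token target up to `length` with the pad value.
    assert x.ndim == 1
    return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)

print(_pad_stop_target(np.array([0., 0., 1.]), 5))  # [0. 0. 1. 0. 0.]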
@@ -391,7 +391,9 @@ class KeepAverage():
             self.update_value(key, value)
 
 
-def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None):
+def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None):
+    if alternative in c.keys() and c[alternative] is not None:
+        return
     if restricted:
         assert name in c.keys(), f' [!] {name} not defined in config.json'
     if name in c.keys():
@@ -417,8 +419,8 @@ def check_config(c):
     _check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
     _check_argument('num_freq', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
     _check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
-    _check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000)
-    _check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000)
+    _check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
+    _check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
     _check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
     _check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
     _check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
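The new alternative keyword lets a required key be satisfied by a different one, so configs that define win_length/hop_length directly no longer need frame_length_ms/frame_shift_ms. A simplified sketch of just that branch, with type and range checks omitted:

def _check_argument(name, c, restricted=False, alternative=None):
    # Skip the check entirely when the alternative key is present and set.
    if alternative in c.keys() and c[alternative] is not None:
        return
    if restricted:
        assert name in c.keys(), f' [!] {name} not defined in config.json'

audio_cfg = {'win_length': 1024, 'hop_length': 256}  # no frame_length_ms / frame_shift_ms
_check_argument('frame_length_ms', audio_cfg, restricted=True, alternative='win_length')  # passes
# _check_argument('num_mels', audio_cfg, restricted=True)  # would raise AssertionError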
@@ -70,6 +70,24 @@ def id_to_torch(speaker_id):
     return speaker_id
 
 
+# TODO: perform GL with pytorch for batching
+def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
+    '''Apply griffin-lim to each sample iterating throught the first dimension.
+    Args:
+        inputs (Tensor or np.Array): Features to be converted by GL. First dimension is the batch size.
+        input_lens (Tensor or np.Array): 1D array of sample lengths.
+        CONFIG (Dict): TTS config.
+        ap (AudioProcessor): TTS audio processor.
+    '''
+    wavs = []
+    for idx, spec in enumerate(inputs):
+        wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length # inverse librosa padding
+        wav = inv_spectrogram(spec, ap, CONFIG)
+        # assert len(wav) == wav_len, f" [!] wav lenght: {len(wav)} vs expected: {wav_len}"
+        wavs.append(wav[:wav_len])
+    return wavs
+
+
 def synthesis(model,
               text,
               CONFIG,
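apply_griffin_lim runs Griffin-Lim sample by sample and trims each result to frames * hop_length - hop_length, undoing the extra hop introduced by librosa's centered STFT. A rough equivalent built on librosa's own Griffin-Lim, standing in for the repository's inv_spectrogram and AudioProcessor:

import librosa

def griffin_lim_batch(linear_specs, spec_lens, hop_length=256, n_iter=30):
    # linear_specs: iterable of magnitude spectrograms shaped (n_fft // 2 + 1, max_frames).
    # spec_lens: number of valid frames per sample.
    wavs = []
    for spec, n_frames in zip(linear_specs, spec_lens):
        wav = librosa.griffinlim(spec, n_iter=n_iter, hop_length=hop_length)
        wav_len = n_frames * hop_length - hop_length  # same length bookkeeping as in the diff
        wavs.append(wav[:wav_len])
    return wavs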
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 
 import re
+from packaging import version
 import phonemizer
 from phonemizer.phonemize import phonemize
 from TTS.utils.text import cleaners
@@ -28,7 +29,7 @@ def text2phone(text, language):
     seperator = phonemizer.separator.Separator(' |', '', '|')
     #try:
     punctuations = re.findall(PHONEME_PUNCTUATION_PATTERN, text)
-    if float(phonemizer.__version__) < 2.1:
+    if version.parse(phonemizer.__version__) < version.parse('2.1'):
         ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language)
         ph = ph[:-1].strip() # skip the last empty character
         # phonemizer does not tackle punctuations. Here we do.
@@ -42,7 +43,7 @@ def text2phone(text, language):
         else:
             for punct in punctuations:
                 ph = ph.replace('| |\n', '|'+punct+'| |', 1)
-    elif float(phonemizer.__version__) > 2.1:
+    elif version.parse(phonemizer.__version__) >= version.parse('2.1'):
         ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language, preserve_punctuation=True)
         # this is a simple fix for phonemizer.
         # https://github.com/bootphon/phonemizer/issues/32
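Switching from float(phonemizer.__version__) to packaging.version makes the branch robust to real version strings: float('2.1.1') raises ValueError, and the comparison also changes from > to >= so that version 2.1 itself takes the newer code path. For example:

from packaging import version

print(version.parse('2.1.1') >= version.parse('2.1'))  # True
print(version.parse('2.10') > version.parse('2.2'))    # True, although '2.10' < '2.2' as strings
# float('2.1.1')                                       # ValueError: could not convert string to float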
@@ -63,6 +63,19 @@ def convert_to_ascii(text):
     return unidecode(text)
 
 
+def remove_aux_symbols(text):
+    text = re.sub(r'[\<\>\(\)\[\]\"]+', '', text)
+    return text
+
+
+def replace_symbols(text):
+    text = text.replace(';', ',')
+    text = text.replace('-', ' ')
+    text = text.replace(':', ' ')
+    text = text.replace('&', 'and')
+    return text
+
+
 def basic_cleaners(text):
     '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
     text = lowercase(text)
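The two new helpers normalize punctuation before whitespace is collapsed: replace_symbols rewrites a few separators into words or spaces, and remove_aux_symbols strips brackets and quotes. A quick run on a made-up sentence; the doubled space it leaves behind is removed later by collapse_whitespace in the cleaners below:

import re

def replace_symbols(text):
    text = text.replace(';', ',')
    text = text.replace('-', ' ')
    text = text.replace(':', ' ')
    text = text.replace('&', 'and')
    return text

def remove_aux_symbols(text):
    return re.sub(r'[\<\>\(\)\[\]\"]+', '', text)

sample = 'The [quick] fox & hound: a re-match; maybe.'
print(remove_aux_symbols(replace_symbols(sample)))
# -> The quick fox and hound  a re match, maybe.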
@@ -84,6 +97,8 @@ def english_cleaners(text):
     text = lowercase(text)
     text = expand_numbers(text)
     text = expand_abbreviations(text)
+    text = replace_symbols(text)
+    text = remove_aux_symbols(text)
     text = collapse_whitespace(text)
     return text
 
@@ -93,5 +108,7 @@ def phoneme_cleaners(text):
     text = convert_to_ascii(text)
     text = expand_numbers(text)
     text = expand_abbreviations(text)
+    text = replace_symbols(text)
+    text = remove_aux_symbols(text)
     text = collapse_whitespace(text)
     return text