Merge branch 'dev' of https://github.com/mozilla/TTS into dev

erogol 2020-03-09 10:25:13 +01:00
commit 0f78f5c277
13 changed files with 97 additions and 23 deletions

View File

@ -1,7 +1,7 @@
{
"model": "Tacotron2", // one of the model in models/
"run_name": "ljspeech-gravesv2",
"run_description": "tacotron2 wuth graves attention",
"run_name": "ljspeech-stft_params",
"run_description": "tacotron2 cosntant stf parameters",
// AUDIO PARAMETERS
"audio":{
@ -50,12 +50,11 @@
"reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
// TRAINING
"batch_size": 2, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
"eval_batch_size":16,
"r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
"gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceeed.
"loss_masking": true, // enable / disable loss masking against the sequence padding.
"grad_accum": 2, // if N > 1, enable gradient accumulation for N iterations. It is useful for low memory GPUs.
// VALIDATION
"run_eval": true,
@ -110,7 +109,7 @@
"output_path": "/data4/rw/home/Trainings/",
// PHONEMES
"phoneme_cache_path": "mozilla_us_phonemes", // phoneme computation is slow, therefore, it caches results in the given folder.
"phoneme_cache_path": "mozilla_us_phonemes_2_1", // phoneme computation is slow, therefore, it caches results in the given folder.
"use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
"phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages

View File

@ -84,7 +84,7 @@ def mozilla_de(root_path, meta_file):
        for line in ttf:
            cols = line.strip().split('|')
            wav_file = cols[0].strip()
            text = cols[1].strip()
            text = cols[1].strip()
            folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
            wav_file = os.path.join(root_path, folder_name, wav_file)
            items.append([text, wav_file, speaker_name])

View File

@ -96,3 +96,32 @@ class AttentionEntropyLoss(nn.Module):
        entropy = torch.distributions.Categorical(probs=align).entropy()
        loss = (entropy / np.log(align.shape[1])).mean()
        return loss


class BCELossMasked(nn.Module):

    def __init__(self, pos_weight):
        super(BCELossMasked, self).__init__()
        self.pos_weight = pos_weight

    def forward(self, x, target, length):
        """
        Args:
            x: A Variable containing a FloatTensor of size
                (batch, max_len) which contains the
                unnormalized stop-token logits for each decoder step.
            target: A Variable containing a FloatTensor of size
                (batch, max_len) which contains the binary stop
                labels (1 at the stop frame, 0 elsewhere).
            length: A Variable containing a LongTensor of size (batch,)
                which contains the length of each data in a batch.
        Returns:
            loss: An average loss value masked by the length.
        """
        # mask: (batch, max_len)
        target.requires_grad = False
        mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
        loss = functional.binary_cross_entropy_with_logits(
            x * mask, target * mask, pos_weight=self.pos_weight, reduction='sum')
        loss = loss / mask.sum()
        return loss
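
A minimal usage sketch of the new masked stop-token loss, with a stand-in sequence_mask (the real module imports the project's own helper); shapes and numbers are illustrative:

import torch
from torch.nn import functional

def sequence_mask(sequence_length, max_len=None):
    # stand-in for the project's helper: True for valid steps, False for padding
    max_len = max_len or sequence_length.max().item()
    steps = torch.arange(max_len, device=sequence_length.device)
    return steps.unsqueeze(0) < sequence_length.unsqueeze(1)

batch, max_len = 2, 5
logits = torch.randn(batch, max_len)        # stopnet outputs (unnormalized logits)
targets = torch.zeros(batch, max_len)       # 1.0 marks the stop frame
targets[0, 3] = 1.0
targets[1, 4] = 1.0
lengths = torch.tensor([4, 5])

mask = sequence_mask(lengths, max_len).float()
loss = functional.binary_cross_entropy_with_logits(
    logits * mask, targets * mask, pos_weight=torch.tensor(10.0), reduction='sum')
loss = loss / mask.sum()                    # average over valid steps only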

View File

@ -39,7 +39,7 @@ class Tacotron(nn.Module):
encoder_dim = 512 if num_speakers > 1 else 256
proj_speaker_dim = 80 if num_speakers > 1 else 0
# embedding layer
self.embedding = nn.Embedding(num_chars, 256)
self.embedding = nn.Embedding(num_chars, 256, padding_idx=0)
self.embedding.weight.data.normal_(0, 0.3)
# boilerplate model
self.encoder = Encoder(encoder_dim)

View File

@ -35,7 +35,7 @@ class Tacotron2(nn.Module):
encoder_dim = 512 if num_speakers > 1 else 512
proj_speaker_dim = 80 if num_speakers > 1 else 0
# embedding layer
self.embedding = nn.Embedding(num_chars, 512)
self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
std = sqrt(2.0 / (num_chars + 512))
val = sqrt(3.0) * std # uniform bounds for std
self.embedding.weight.data.uniform_(-val, val)
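
Both embedding changes (Tacotron and Tacotron2) add padding_idx=0, which tells nn.Embedding to keep the gradient of the padding row at zero, so padded character positions never shift the embedding during training; the manual weight init that follows still overwrites the default zero row. A quick illustration:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4, padding_idx=0)
out = emb(torch.tensor([[0, 3, 0, 7]]))   # index 0 is the padding symbol
out.sum().backward()
print(emb.weight.grad[0])                 # all zeros: the pad row gets no gradient
print(emb.weight.grad[3])                 # non-zero: real characters update normally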

View File

@ -85,7 +85,10 @@
" if use_cuda:\n",
" waveform = waveform.cpu()\n",
" waveform = waveform.numpy()\n",
" print(\" > Run-time: {}\".format(time.time() - t_1))\n",
" rtf = (time.time() - t_1) / (len(waveform) / ap.sample_rate)\n",
" print(waveform.shape)\n",
" print(\" > Run-time: {}\".format(time.time() - t_1))\n",
" print(\" > Real-time factor: {}\".format(rtf))\n",
" if figures: \n",
" visualize(alignment, mel_postnet_spec, stop_tokens, text, ap.hop_length, CONFIG, ap._denormalize(mel_spec)) \n",
" IPython.display.display(Audio(waveform, rate=CONFIG.audio['sample_rate'], normalize=False)) \n",

View File

@ -105,8 +105,8 @@ class Synthesizer(object):
            sample_rate=self.ap.sample_rate,
        ).cuda()
        check = torch.load(model_file)
        self.wavernn.load_state_dict(check['model'], map_location="cpu")
        check = torch.load(model_file, map_location="cpu")
        self.wavernn.load_state_dict(check['model'])
        if use_cuda:
            self.wavernn.cuda()
        self.wavernn.eval()
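
The fix here moves map_location from load_state_dict (which has no such argument) to torch.load, where it remaps storages while the checkpoint is deserialized, so a GPU-trained WaveRNN checkpoint can be loaded on a CPU-only machine. A self-contained sketch of the pattern (the checkpoint path and model are stand-ins):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)                                   # stand-in for the WaveRNN model
torch.save({'model': model.state_dict()}, "checkpoint.pth")

check = torch.load("checkpoint.pth", map_location="cpu")  # remap CUDA storages to CPU here
model.load_state_dict(check['model'])                     # load_state_dict takes no map_location
if torch.cuda.is_available():
    model.cuda()
model.eval()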

View File

@ -13,7 +13,7 @@ from torch.utils.data import DataLoader
from TTS.datasets.TTSDataset import MyDataset
from distribute import (DistributedSampler, apply_gradient_allreduce,
init_distributed, reduce_tensor)
from TTS.layers.losses import L1LossMasked, MSELossMasked
from TTS.layers.losses import L1LossMasked, MSELossMasked, BCELossMasked
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import (
NoamLR, check_update, count_parameters, create_experiment_folder,
@ -168,7 +168,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
# loss computation
stop_loss = criterion_st(stop_tokens,
stop_targets) if c.stopnet else torch.zeros(1)
stop_targets, mel_lengths) if c.stopnet else torch.zeros(1)
if c.loss_masking:
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
if c.model in ["Tacotron", "TacotronGST"]:
@ -366,7 +366,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
# loss computation
stop_loss = criterion_st(
stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
stop_tokens, stop_targets, mel_lengths) if c.stopnet else torch.zeros(1)
if c.loss_masking:
decoder_loss = criterion(decoder_output, mel_input,
mel_lengths)
@ -494,7 +494,12 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
use_cuda,
ap,
speaker_id=speaker_id,
style_wav=style_wav)
style_wav=style_wav,
truncated=False,
enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
use_griffin_lim=True,
do_trim_silence=False)
file_path = os.path.join(AUDIO_PATH, str(global_step))
os.makedirs(file_path, exist_ok=True)
file_path = os.path.join(file_path,
@ -570,7 +575,7 @@ def main(args): # pylint: disable=redefined-outer-name
else:
criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
] else nn.MSELoss()
criterion_st = nn.BCEWithLogitsLoss(
criterion_st = BCELossMasked(
pos_weight=torch.tensor(10)) if c.stopnet else None
if args.restore_path:

View File

@ -14,7 +14,7 @@ def prepare_data(inputs):
def _pad_tensor(x, length):
    _pad = 0
    _pad = 0.
    assert x.ndim == 2
    x = np.pad(
        x, [[0, 0], [0, length - x.shape[1]]],
@ -31,7 +31,7 @@ def prepare_tensor(inputs, out_steps):
def _pad_stop_target(x, length):
    _pad = 1.
    _pad = 0.
    assert x.ndim == 1
    return np.pad(
        x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
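
Padding stop targets with 0 instead of 1 matches the new masked loss: the mask zeroes out the padded region anyway, so the padding value no longer injects spurious positive stop labels beyond the real sequence length. For example:

import numpy as np

def _pad_stop_target_sketch(x, length, pad_value=0.):
    # pad a 1D stop-target vector out to `length` frames (sketch of the helper above)
    return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=pad_value)

stops = np.array([0., 0., 0., 1.])           # stop fires on the last real frame
print(_pad_stop_target_sketch(stops, 6))     # [0. 0. 0. 1. 0. 0.]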

View File

@ -391,7 +391,9 @@ class KeepAverage():
            self.update_value(key, value)


def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None):
def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None):
    if alternative in c.keys() and c[alternative] is not None:
        return
    if restricted:
        assert name in c.keys(), f' [!] {name} not defined in config.json'
    if name in c.keys():
@ -417,8 +419,8 @@ def check_config(c):
    _check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
    _check_argument('num_freq', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
    _check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
    _check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000)
    _check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000)
    _check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
    _check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
    _check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
    _check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
    _check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
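
The new alternative parameter lets a restricted key pass the check when a named replacement key is present, which is what allows win_length/hop_length to stand in for frame_length_ms/frame_shift_ms. A simplified sketch of the behaviour (not the full checker):

def check_with_alternative(name, c, restricted=False, alternative=None):
    # simplified: skip the check entirely if the alternative key is set
    if alternative in c.keys() and c[alternative] is not None:
        return
    if restricted:
        assert name in c.keys(), f' [!] {name} not defined in config.json'

audio_cfg = {'win_length': 1024, 'hop_length': 256}
check_with_alternative('frame_length_ms', audio_cfg, restricted=True, alternative='win_length')  # passes
check_with_alternative('frame_shift_ms', audio_cfg, restricted=True, alternative='hop_length')   # passes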

View File

@ -70,6 +70,24 @@ def id_to_torch(speaker_id):
return speaker_id
# TODO: perform GL with pytorch for batching
def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
    '''Apply Griffin-Lim to each sample, iterating through the first dimension.
    Args:
        inputs (Tensor or np.Array): Features to be converted by GL. First dimension is the batch size.
        input_lens (Tensor or np.Array): 1D array of sample lengths.
        CONFIG (Dict): TTS config.
        ap (AudioProcessor): TTS audio processor.
    '''
    wavs = []
    for idx, spec in enumerate(inputs):
        wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length  # inverse librosa padding
        wav = inv_spectrogram(spec, ap, CONFIG)
        # assert len(wav) == wav_len, f" [!] wav length: {len(wav)} vs expected: {wav_len}"
        wavs.append(wav[:wav_len])
    return wavs


def synthesis(model,
              text,
              CONFIG,
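
The trim at the end of apply_griffin_lim cuts the inverted waveform back to the expected length: number of frames times hop_length, minus one hop to undo librosa's padding, as the inline comment notes. With illustrative numbers:

hop_length = 256                              # example hop size
n_frames = 400                                # example spectrogram length
wav_len = (n_frames * hop_length) - hop_length
print(wav_len)                                # 102144 samples kept from the inverted wav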

View File

@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
import re
from packaging import version
import phonemizer
from phonemizer.phonemize import phonemize
from TTS.utils.text import cleaners
@ -28,7 +29,7 @@ def text2phone(text, language):
seperator = phonemizer.separator.Separator(' |', '', '|')
#try:
punctuations = re.findall(PHONEME_PUNCTUATION_PATTERN, text)
if float(phonemizer.__version__) < 2.1:
if version.parse(phonemizer.__version__) < version.parse('2.1'):
ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language)
ph = ph[:-1].strip() # skip the last empty character
# phonemizer does not tackle punctuations. Here we do.
@ -42,7 +43,7 @@ def text2phone(text, language):
else:
for punct in punctuations:
ph = ph.replace('| |\n', '|'+punct+'| |', 1)
elif float(phonemizer.__version__) > 2.1:
elif version.parse(phonemizer.__version__) >= version.parse('2.1'):
ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language, preserve_punctuation=True)
# this is a simple fix for phonemizer.
# https://github.com/bootphon/phonemizer/issues/32
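
Comparing versions with float() breaks as soon as phonemizer reports a patch release (float('2.1.1') raises ValueError), and the old > 2.1 branch also skipped exactly 2.1; packaging.version handles both cases. For example:

from packaging import version

print(version.parse("2.1.1") >= version.parse("2.1"))   # True: patch releases take the new branch
print(version.parse("2.0.1") < version.parse("2.1"))    # True: older installs keep the legacy path
print(version.parse("2.1") >= version.parse("2.1"))     # True: 2.1 itself is no longer skipped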

View File

@ -63,6 +63,19 @@ def convert_to_ascii(text):
    return unidecode(text)


def remove_aux_symbols(text):
    text = re.sub(r'[\<\>\(\)\[\]\"]+', '', text)
    return text


def replace_symbols(text):
    text = text.replace(';', ',')
    text = text.replace('-', ' ')
    text = text.replace(':', ' ')
    text = text.replace('&', 'and')
    return text
def basic_cleaners(text):
    '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
    text = lowercase(text)
@ -84,6 +97,8 @@ def english_cleaners(text):
    text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = replace_symbols(text)
    text = remove_aux_symbols(text)
    text = collapse_whitespace(text)
    return text
@ -93,5 +108,7 @@ def phoneme_cleaners(text):
    text = convert_to_ascii(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = replace_symbols(text)
    text = remove_aux_symbols(text)
    text = collapse_whitespace(text)
    return text
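
To see what the two new cleaning steps do together, here is a small self-contained example mirroring the functions above (with a minimal whitespace collapse standing in for the module's own collapse_whitespace):

import re

def replace_symbols(text):
    text = text.replace(';', ',')
    text = text.replace('-', ' ')
    text = text.replace(':', ' ')
    text = text.replace('&', 'and')
    return text

def remove_aux_symbols(text):
    return re.sub(r'[\<\>\(\)\[\]\"]+', '', text)

def collapse_whitespace(text):
    return re.sub(r'\s+', ' ', text).strip()   # minimal stand-in for the module helper

sample = 'He said: "wait - now & here (please)"'
print(collapse_whitespace(remove_aux_symbols(replace_symbols(sample))))
# -> He said wait now and here please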