diff --git a/config.json b/config.json index ef999fa9..2b2b03a5 100644 --- a/config.json +++ b/config.json @@ -1,7 +1,7 @@ { "model": "Tacotron2", // one of the model in models/ - "run_name": "ljspeech-graves", - "run_description": "tacotron2 wuth graves attention", + "run_name": "ljspeech-stft_params", + "run_description": "tacotron2 cosntant stf parameters", // AUDIO PARAMETERS "audio":{ @@ -9,8 +9,10 @@ "num_mels": 80, // size of the mel spec frame. "num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame. "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. - "frame_length_ms": 50, // stft window length in ms. - "frame_shift_ms": 12.5, // stft window hop-lengh in ms. + "win_length": 1024, // stft window length in ms. + "hop_length": 256, // stft window hop-lengh in ms. + "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. + "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. "min_level_db": -100, // normalization range "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. @@ -19,13 +21,26 @@ // Normalization parameters "signal_norm": true, // normalize the spec values in range [0, 1] "symmetric_norm": true, // move normalization to range [-1, 1] - "max_norm": 4, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] "clip_norm": true, // clip normalized values into the range. "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! - "do_trim_silence": true // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) + "do_trim_silence": true, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) + "trim_db": 60 // threshold for timming silence. Set this according to your dataset. }, + // VOCABULARY PARAMETERS + // if custom character set is not defined, + // default set in symbols.py is used + "characters":{ + "pad": "_", + "eos": "~", + "bos": "^", + "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ", + "punctuations":"!'(),-.:;? ", + "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ" + }, + // DISTRIBUTED TRAINING "distributed":{ "backend": "nccl", @@ -48,11 +63,12 @@ // OPTIMIZER "noam_schedule": false, // use noam warmup and lr schedule. - "grad_clip": 1, // upper limit for gradients for clipping. + "grad_clip": 1.0, // upper limit for gradients for clipping. "epochs": 1000, // total number of epochs to train. "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. "wd": 0.000001, // Weight decay weight. "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" + "seq_len_norm": false, // Normalize eash sample loss with its length to alleviate imbalanced datasets. Use it if your dataset is small or has skewed distribution of sequence lengths. // TACOTRON PRENET "memory_size": -1, // ONLY TACOTRON - size of the memory queue used fro storing last decoder predictions for auto-regression. If < 0, memory queue is disabled and decoder only uses the last prediction frame. @@ -61,13 +77,13 @@ // ATTENTION "attention_type": "original", // 'original' or 'graves' - "attention_heads": 5, // number of attention heads (only for 'graves') + "attention_heads": 4, // number of attention heads (only for 'graves') "attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron. "windowing": false, // Enables attention windowing. Used only in eval mode. "use_forward_attn": false, // if it uses forward attention. In general, it aligns faster. "forward_attn_mask": false, // Additional masking forcing monotonicity only in eval mode. "transition_agent": false, // enable/disable transition agent of forward attention. - "location_attn": true, // enable_disable location sensitive attention. It is enabled for TACOTRON by default. + "location_attn": false, // enable_disable location sensitive attention. It is enabled for TACOTRON by default. "bidirectional_decoder": false, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset. // STOPNET @@ -90,11 +106,10 @@ "max_seq_len": 150, // DATASET-RELATED: maximum text length // PATHS - "output_path": "/data5/rw/pit/keep/", // DATASET-RELATED: output path for all training outputs. - // "output_path": "/media/erogol/data_ssd/Models/runs/", + "output_path": "/data4/rw/home/Trainings/", // PHONEMES - "phoneme_cache_path": "mozilla_us_phonemes", // phoneme computation is slow, therefore, it caches results in the given folder. + "phoneme_cache_path": "mozilla_us_phonemes_2_1", // phoneme computation is slow, therefore, it caches results in the given folder. "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages @@ -108,10 +123,9 @@ [ { "name": "ljspeech", - "path": "/data5/ro/shared/data/keithito/LJSpeech-1.1/", - // "path": "/home/erogol/Data/LJSpeech-1.1", - "meta_file_train": "metadata_train.csv", - "meta_file_val": "metadata_val.csv" + "path": "/root/LJSpeech-1.1/", + "meta_file_train": "metadata.csv", + "meta_file_val": null } ] diff --git a/datasets/TTSDataset.py b/datasets/TTSDataset.py index a45d77ff..d3a6f486 100644 --- a/datasets/TTSDataset.py +++ b/datasets/TTSDataset.py @@ -15,6 +15,7 @@ class MyDataset(Dataset): text_cleaner, ap, meta_data, + tp=None, batch_group_size=0, min_seq_len=0, max_seq_len=float("inf"), @@ -49,6 +50,7 @@ class MyDataset(Dataset): self.min_seq_len = min_seq_len self.max_seq_len = max_seq_len self.ap = ap + self.tp = tp self.use_phonemes = use_phonemes self.phoneme_cache_path = phoneme_cache_path self.phoneme_language = phoneme_language @@ -75,13 +77,13 @@ class MyDataset(Dataset): def _generate_and_cache_phoneme_sequence(self, text, cache_path): """generate a phoneme sequence from text. - since the usage is for subsequent caching, we never add bos and eos chars here. Instead we add those dynamically later; based on the config option.""" phonemes = phoneme_to_sequence(text, [self.cleaners], language=self.phoneme_language, - enable_eos_bos=False) + enable_eos_bos=False, + tp=self.tp) phonemes = np.asarray(phonemes, dtype=np.int32) np.save(cache_path, phonemes) return phonemes @@ -101,7 +103,7 @@ class MyDataset(Dataset): phonemes = self._generate_and_cache_phoneme_sequence(text, cache_path) if self.enable_eos_bos: - phonemes = pad_with_eos_bos(phonemes) + phonemes = pad_with_eos_bos(phonemes, tp=self.tp) phonemes = np.asarray(phonemes, dtype=np.int32) return phonemes @@ -113,7 +115,7 @@ class MyDataset(Dataset): text = self._load_or_generate_phoneme_sequence(wav_file, text) else: text = np.asarray( - text_to_sequence(text, [self.cleaners]), dtype=np.int32) + text_to_sequence(text, [self.cleaners], tp=self.tp), dtype=np.int32) assert text.size > 0, self.items[idx][1] assert wav.size > 0, self.items[idx][1] @@ -193,7 +195,7 @@ class MyDataset(Dataset): mel = [self.ap.melspectrogram(w).astype('float32') for w in wav] linear = [self.ap.spectrogram(w).astype('float32') for w in wav] - mel_lengths = [m.shape[1] for m in mel] + mel_lengths = [m.shape[1] for m in mel] # compute 'stop token' targets stop_targets = [ diff --git a/datasets/preprocess.py b/datasets/preprocess.py index a78abab9..029922d3 100644 --- a/datasets/preprocess.py +++ b/datasets/preprocess.py @@ -60,22 +60,6 @@ def tweb(root_path, meta_file): # return {'text': texts, 'wavs': wavs} -def mozilla_old(root_path, meta_file): - """Normalizes Mozilla meta data files to TTS format""" - txt_file = os.path.join(root_path, meta_file) - items = [] - speaker_name = "mozilla_old" - with open(txt_file, 'r') as ttf: - for line in ttf: - cols = line.split('|') - batch_no = int(cols[1].strip().split("_")[0]) - wav_folder = "batch{}".format(batch_no) - wav_file = os.path.join(root_path, wav_folder, "wavs_no_processing", cols[1].strip()) - text = cols[0].strip() - items.append([text, wav_file, speaker_name]) - return items - - def mozilla(root_path, meta_file): """Normalizes Mozilla meta data files to TTS format""" txt_file = os.path.join(root_path, meta_file) @@ -91,6 +75,22 @@ def mozilla(root_path, meta_file): return items +def mozilla_de(root_path, meta_file): + """Normalizes Mozilla meta data files to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "mozilla" + with open(txt_file, 'r', encoding="ISO 8859-1") as ttf: + for line in ttf: + cols = line.strip().split('|') + wav_file = cols[0].strip() + text = cols[1].strip() + folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL" + wav_file = os.path.join(root_path, folder_name, wav_file) + items.append([text, wav_file, speaker_name]) + return items + + def mailabs(root_path, meta_files=None): """Normalizes M-AI-Labs meta data files to TTS format""" speaker_regex = re.compile("by_book/(male|female)/(?P[^/]+)/") diff --git a/layers/common_layers.py b/layers/common_layers.py index c2b042b0..592f017c 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -110,6 +110,86 @@ class LocationLayer(nn.Module): return processed_attention +class GravesAttention(nn.Module): + """ Discretized Graves attention: + - https://arxiv.org/abs/1910.10288 + - https://arxiv.org/pdf/1906.01083.pdf + """ + COEF = 0.3989422917366028 # numpy.sqrt(1/(2*numpy.pi)) + + def __init__(self, query_dim, K): + super(GravesAttention, self).__init__() + self._mask_value = 1e-8 + self.K = K + # self.attention_alignment = 0.05 + self.eps = 1e-5 + self.J = None + self.N_a = nn.Sequential( + nn.Linear(query_dim, query_dim, bias=True), + nn.ReLU(), + nn.Linear(query_dim, 3*K, bias=True)) + self.attention_weights = None + self.mu_prev = None + self.init_layers() + + def init_layers(self): + torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.) # bias mean + torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10) # bias std + + def init_states(self, inputs): + if self.J is None or inputs.shape[1]+1 > self.J.shape[-1]: + self.J = torch.arange(0, inputs.shape[1]+2).to(inputs.device) + 0.5 + self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device) + self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device) + + # pylint: disable=R0201 + # pylint: disable=unused-argument + def preprocess_inputs(self, inputs): + return None + + def forward(self, query, inputs, processed_inputs, mask): + """ + shapes: + query: B x D_attention_rnn + inputs: B x T_in x D_encoder + processed_inputs: place_holder + mask: B x T_in + """ + gbk_t = self.N_a(query) + gbk_t = gbk_t.view(gbk_t.size(0), -1, self.K) + + # attention model parameters + # each B x K + g_t = gbk_t[:, 0, :] + b_t = gbk_t[:, 1, :] + k_t = gbk_t[:, 2, :] + + # attention GMM parameters + sig_t = torch.nn.functional.softplus(b_t) + self.eps + + mu_t = self.mu_prev + torch.nn.functional.softplus(k_t) + g_t = torch.softmax(g_t, dim=-1) + self.eps + + j = self.J[:inputs.size(1)+1] + + # attention weights + phi_t = g_t.unsqueeze(-1) * (1 / (1 + torch.sigmoid((mu_t.unsqueeze(-1) - j) / sig_t.unsqueeze(-1)))) + + # discritize attention weights + alpha_t = torch.sum(phi_t, 1) + alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1] + alpha_t[alpha_t == 0] = 1e-8 + + # apply masking + if mask is not None: + alpha_t.data.masked_fill_(~mask, self._mask_value) + + context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1) + self.attention_weights = alpha_t + self.mu_prev = mu_t + return context + + class OriginalAttention(nn.Module): """Following the methods proposed here: - https://arxiv.org/abs/1712.05884 @@ -289,82 +369,6 @@ class OriginalAttention(nn.Module): return context -class GravesAttention(nn.Module): - """ Graves attention as described here: - - https://arxiv.org/abs/1910.10288 - """ - COEF = 0.3989422917366028 # numpy.sqrt(1/(2*numpy.pi)) - - def __init__(self, query_dim, K): - super(GravesAttention, self).__init__() - self._mask_value = 0.0 - self.K = K - # self.attention_alignment = 0.05 - self.eps = 1e-5 - self.J = None - self.N_a = nn.Sequential( - nn.Linear(query_dim, query_dim, bias=True), - nn.ReLU(), - nn.Linear(query_dim, 3*K, bias=True)) - self.attention_weights = None - self.mu_prev = None - self.init_layers() - - def init_layers(self): - torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.) - torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10) - - def init_states(self, inputs): - if self.J is None or inputs.shape[1] > self.J.shape[-1]: - self.J = torch.arange(0, inputs.shape[1]).to(inputs.device) - self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device) - self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device) - - # pylint: disable=R0201 - # pylint: disable=unused-argument - def preprocess_inputs(self, inputs): - return None - - def forward(self, query, inputs, processed_inputs, mask): - """ - shapes: - query: B x D_attention_rnn - inputs: B x T_in x D_encoder - processed_inputs: place_holder - mask: B x T_in - """ - gbk_t = self.N_a(query) - gbk_t = gbk_t.view(gbk_t.size(0), -1, self.K) - - # attention model parameters - # each B x K - g_t = gbk_t[:, 0, :] - b_t = gbk_t[:, 1, :] - k_t = gbk_t[:, 2, :] - - # attention GMM parameters - sig_t = torch.nn.functional.softplus(b_t) + self.eps - - mu_t = self.mu_prev + torch.nn.functional.softplus(k_t) - g_t = torch.softmax(g_t, dim=-1) / sig_t + self.eps - - # each B x K x T_in - j = self.J[:inputs.size(1)] - - # attention weights - phi_t = g_t.unsqueeze(-1) * torch.exp(-0.5 * (mu_t.unsqueeze(-1) - j)**2 / (sig_t.unsqueeze(-1)**2)) - alpha_t = self.COEF * torch.sum(phi_t, 1) - - # apply masking - if mask is not None: - alpha_t.data.masked_fill_(~mask, self._mask_value) - - context = torch.bmm(alpha_t.unsqueeze(1), inputs).squeeze(1) - self.attention_weights = alpha_t - self.mu_prev = mu_t - return context - - def init_attn(attn_type, query_dim, embedding_dim, attention_dim, location_attention, attention_location_n_filters, attention_location_kernel_size, windowing, norm, forward_attn, diff --git a/layers/losses.py b/layers/losses.py index e7ecff5f..7e5671b2 100644 --- a/layers/losses.py +++ b/layers/losses.py @@ -6,6 +6,11 @@ from TTS.utils.generic_utils import sequence_mask class L1LossMasked(nn.Module): + + def __init__(self, seq_len_norm): + super(L1LossMasked, self).__init__() + self.seq_len_norm = seq_len_norm + def forward(self, x, target, length): """ Args: @@ -24,14 +29,27 @@ class L1LossMasked(nn.Module): target.requires_grad = False mask = sequence_mask( sequence_length=length, max_len=target.size(1)).unsqueeze(2).float() - mask = mask.expand_as(x) - loss = functional.l1_loss( - x * mask, target * mask, reduction="sum") - loss = loss / mask.sum() + if self.seq_len_norm: + norm_w = mask / mask.sum(dim=1, keepdim=True) + out_weights = norm_w.div(target.shape[0] * target.shape[2]) + mask = mask.expand_as(x) + loss = functional.l1_loss( + x * mask, target * mask, reduction='none') + loss = loss.mul(out_weights.to(loss.device)).sum() + else: + mask = mask.expand_as(x) + loss = functional.l1_loss( + x * mask, target * mask, reduction='sum') + loss = loss / mask.sum() return loss class MSELossMasked(nn.Module): + + def __init__(self, seq_len_norm): + super(MSELossMasked, self).__init__() + self.seq_len_norm = seq_len_norm + def forward(self, x, target, length): """ Args: @@ -50,10 +68,18 @@ class MSELossMasked(nn.Module): target.requires_grad = False mask = sequence_mask( sequence_length=length, max_len=target.size(1)).unsqueeze(2).float() - mask = mask.expand_as(x) - loss = functional.mse_loss( - x * mask, target * mask, reduction="sum") - loss = loss / mask.sum() + if self.seq_len_norm: + norm_w = mask / mask.sum(dim=1, keepdim=True) + out_weights = norm_w.div(target.shape[0] * target.shape[2]) + mask = mask.expand_as(x) + loss = functional.mse_loss( + x * mask, target * mask, reduction='none') + loss = loss.mul(out_weights.to(loss.device)).sum() + else: + mask = mask.expand_as(x) + loss = functional.mse_loss( + x * mask, target * mask, reduction='sum') + loss = loss / mask.sum() return loss @@ -70,3 +96,32 @@ class AttentionEntropyLoss(nn.Module): entropy = torch.distributions.Categorical(probs=align).entropy() loss = (entropy / np.log(align.shape[1])).mean() return loss + + +class BCELossMasked(nn.Module): + + def __init__(self, pos_weight): + super(BCELossMasked, self).__init__() + self.pos_weight = pos_weight + + def forward(self, x, target, length): + """ + Args: + x: A Variable containing a FloatTensor of size + (batch, max_len) which contains the + unnormalized probability for each class. + target: A Variable containing a LongTensor of size + (batch, max_len) which contains the index of the true + class for each corresponding step. + length: A Variable containing a LongTensor of size (batch,) + which contains the length of each data in a batch. + Returns: + loss: An average loss value in range [0, 1] masked by the length. + """ + # mask: (batch, max_len, 1) + target.requires_grad = False + mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float() + loss = functional.binary_cross_entropy_with_logits( + x * mask, target * mask, pos_weight=self.pos_weight, reduction='sum') + loss = loss / mask.sum() + return loss diff --git a/layers/tacotron2.py b/layers/tacotron2.py index 78bdd10d..fa76a6b2 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -64,7 +64,6 @@ class Encoder(nn.Module): def forward(self, x, input_lengths): x = self.convolutions(x) x = x.transpose(1, 2) - input_lengths = input_lengths.cpu().numpy() x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True) @@ -290,7 +289,7 @@ class Decoder(nn.Module): stop_tokens += [stop_token] alignments += [alignment] - if stop_token > 0.7: + if stop_token > 0.7 and t > inputs.shape[0] / 2: break if len(outputs) == self.max_decoder_steps: print(" | > Decoder stopped with 'max_decoder_steps") diff --git a/models/tacotron.py b/models/tacotron.py index a2d9e1c4..fba82b1b 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -39,7 +39,7 @@ class Tacotron(nn.Module): encoder_dim = 512 if num_speakers > 1 else 256 proj_speaker_dim = 80 if num_speakers > 1 else 0 # embedding layer - self.embedding = nn.Embedding(num_chars, 256) + self.embedding = nn.Embedding(num_chars, 256, padding_idx=0) self.embedding.weight.data.normal_(0, 0.3) # boilerplate model self.encoder = Encoder(encoder_dim) @@ -132,6 +132,7 @@ class Tacotron(nn.Module): return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward return decoder_outputs, postnet_outputs, alignments, stop_tokens + @torch.no_grad() def inference(self, characters, speaker_ids=None, style_mel=None): inputs = self.embedding(characters) self._init_states() diff --git a/models/tacotron2.py b/models/tacotron2.py index 852b1886..d530774a 100644 --- a/models/tacotron2.py +++ b/models/tacotron2.py @@ -35,7 +35,7 @@ class Tacotron2(nn.Module): encoder_dim = 512 if num_speakers > 1 else 512 proj_speaker_dim = 80 if num_speakers > 1 else 0 # embedding layer - self.embedding = nn.Embedding(num_chars, 512) + self.embedding = nn.Embedding(num_chars, 512, padding_idx=0) std = sqrt(2.0 / (num_chars + 512)) val = sqrt(3.0) * std # uniform bounds for std self.embedding.weight.data.uniform_(-val, val) @@ -82,6 +82,7 @@ class Tacotron2(nn.Module): return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward return decoder_outputs, postnet_outputs, alignments, stop_tokens + @torch.no_grad() def inference(self, text, speaker_ids=None): embedded_inputs = self.embedding(text).transpose(1, 2) encoder_outputs = self.encoder.inference(embedded_inputs) diff --git a/notebooks/Benchmark-PWGAN.ipynb b/notebooks/Benchmark-PWGAN.ipynb new file mode 100644 index 00000000..082ffa60 --- /dev/null +++ b/notebooks/Benchmark-PWGAN.ipynb @@ -0,0 +1,585 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is to test TTS models with benchmark sentences for speech synthesis.\n", + "\n", + "Before running this script please DON'T FORGET: \n", + "- to set file paths.\n", + "- to download related model files from TTS and PWGAN.\n", + "- download or clone related repos, linked below.\n", + "- setup the repositories. ```python setup.py install```\n", + "- to checkout right commit versions (given next to the model) of TTS and PWGAN.\n", + "- to set the right paths in the cell below.\n", + "\n", + "Repositories:\n", + "- TTS: https://github.com/mozilla/TTS\n", + "- PWGAN: https://github.com/erogol/ParallelWaveGAN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "import os\n", + "import sys\n", + "import io\n", + "import torch \n", + "import time\n", + "import json\n", + "import yaml\n", + "import numpy as np\n", + "from collections import OrderedDict\n", + "import matplotlib.pyplot as plt\n", + "plt.rcParams[\"figure.figsize\"] = (16,5)\n", + "\n", + "import librosa\n", + "import librosa.display\n", + "\n", + "from TTS.models.tacotron import Tacotron \n", + "from TTS.layers import *\n", + "from TTS.utils.data import *\n", + "from TTS.utils.audio import AudioProcessor\n", + "from TTS.utils.generic_utils import load_config, setup_model\n", + "from TTS.utils.text import text_to_sequence\n", + "from TTS.utils.synthesis import synthesis\n", + "from TTS.utils.visual import visualize\n", + "\n", + "import IPython\n", + "from IPython.display import Audio\n", + "\n", + "import os\n", + "\n", + "# you may need to change this depending on your system\n", + "os.environ['CUDA_VISIBLE_DEVICES']='1'\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def tts(model, text, CONFIG, use_cuda, ap, use_gl, figures=True):\n", + " t_1 = time.time()\n", + " waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, False, CONFIG.enable_eos_bos_chars)\n", + " if CONFIG.model == \"Tacotron\" and not use_gl:\n", + " # coorect the normalization differences b/w TTS and the Vocoder.\n", + " mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n", + " mel_postnet_spec = ap._denormalize(mel_postnet_spec)\n", + "# mel_postnet_spec = np.pad(mel_postnet_spec, pad_width=((2, 2), (0, 0)))\n", + " print(mel_postnet_spec.shape)\n", + " print(\"max- \", mel_postnet_spec.max(), \" -- min- \", mel_postnet_spec.min())\n", + " if not use_gl:\n", + " waveform = vocoder_model.inference(torch.FloatTensor(ap_vocoder._normalize(mel_postnet_spec).T).unsqueeze(0), hop_size=ap_vocoder.hop_length)\n", + "# waveform = waveform / abs(waveform).max() * 0.9\n", + " if use_cuda:\n", + " waveform = waveform.cpu()\n", + " waveform = waveform.numpy()\n", + " rtf = (time.time() - t_1) / (len(waveform) / ap.sample_rate)\n", + " print(waveform.shape)\n", + " print(\" > Run-time: {}\".format(time.time() - t_1))\n", + " print(\" > Real-time factor: {}\".format(rtf))\n", + " if figures: \n", + " visualize(alignment, mel_postnet_spec, stop_tokens, text, ap.hop_length, CONFIG, ap._denormalize(mel_spec)) \n", + " IPython.display.display(Audio(waveform, rate=CONFIG.audio['sample_rate'], normalize=False)) \n", + " os.makedirs(OUT_FOLDER, exist_ok=True)\n", + " file_name = text.replace(\" \", \"_\").replace(\".\",\"\") + \".wav\"\n", + " out_path = os.path.join(OUT_FOLDER, file_name)\n", + " ap.save_wav(waveform, out_path)\n", + " return alignment, mel_postnet_spec, stop_tokens, waveform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Set constants\n", + "ROOT_PATH = '/home/erogol/Models/LJSpeech/ljspeech-bn-December-23-2019_08+34AM-ffea133/'\n", + "MODEL_PATH = ROOT_PATH + '/checkpoint_670000.pth.tar'\n", + "CONFIG_PATH = ROOT_PATH + '/config.json'\n", + "OUT_FOLDER = '/home/erogol/Dropbox/AudioSamples/benchmark_samples/'\n", + "CONFIG = load_config(CONFIG_PATH)\n", + "VOCODER_MODEL_PATH = \"/home/erogol/Models/LJSpeech/pwgan-ljspeech/checkpoint-400000steps.pkl\"\n", + "VOCODER_CONFIG_PATH = \"/home/erogol/Models/LJSpeech/pwgan-ljspeech/config.yml\"\n", + "\n", + "# load PWGAN config\n", + "with open(VOCODER_CONFIG_PATH) as f:\n", + " VOCODER_CONFIG = yaml.load(f, Loader=yaml.Loader)\n", + " \n", + "# Run FLAGs\n", + "use_cuda = False\n", + "# Set some config fields manually for testing\n", + "CONFIG.windowing = True\n", + "CONFIG.use_forward_attn = True \n", + "# Set the vocoder\n", + "use_gl = False # use GL if True\n", + "batched_wavernn = True # use batched wavernn inference if True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# LOAD TTS MODEL\n", + "from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n", + "\n", + "# multi speaker \n", + "if CONFIG.use_speaker_embedding:\n", + " speakers = json.load(open(f\"{ROOT_PATH}/speakers.json\", 'r'))\n", + " speakers_idx_to_id = {v: k for k, v in speakers.items()}\n", + "else:\n", + " speakers = []\n", + " speaker_id = None\n", + "\n", + "# if the vocabulary was passed, replace the default\n", + "if 'characters' in CONFIG.keys():\n", + " symbols, phonemes = make_symbols(**CONFIG.characters)\n", + "\n", + "# load the model\n", + "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n", + "model = setup_model(num_chars, len(speakers), CONFIG)\n", + "\n", + "# load the audio processor\n", + "ap = AudioProcessor(**CONFIG.audio) \n", + "\n", + "\n", + "# load model state\n", + "cp = torch.load(MODEL_PATH, map_location=torch.device('cpu'))\n", + "\n", + "# load the model\n", + "model.load_state_dict(cp['model'])\n", + "if use_cuda:\n", + " model.cuda()\n", + "model.eval()\n", + "print(cp['step'])\n", + "print(cp['r'])\n", + "\n", + "# set model stepsize\n", + "if 'r' in cp:\n", + " model.decoder.set_r(cp['r'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# LOAD WAVERNN\n", + "if use_gl == False:\n", + " from parallel_wavegan.models import ParallelWaveGANGenerator\n", + " from parallel_wavegan.utils.audio import AudioProcessor as AudioProcessorVocoder\n", + " \n", + " vocoder_model = ParallelWaveGANGenerator(**VOCODER_CONFIG[\"generator_params\"])\n", + " vocoder_model.load_state_dict(torch.load(VOCODER_MODEL_PATH, map_location=\"cpu\")[\"model\"][\"generator\"])\n", + " vocoder_model.remove_weight_norm()\n", + " ap_vocoder = AudioProcessorVocoder(**VOCODER_CONFIG['audio']) \n", + " if use_cuda:\n", + " vocoder_model.cuda()\n", + " vocoder_model.eval();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Comparision with https://mycroft.ai/blog/available-voices/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.eval()\n", + "model.decoder.max_decoder_steps = 2000\n", + "model.decoder.prenet.eval()\n", + "speaker_id = None\n", + "sentence = '''A breeding jennet, lusty, young, and proud,'''\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"Bill got in the habit of asking himself “Is that thought true?” and if he wasn’t absolutely certain it was, he just let it go.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### https://espnet.github.io/icassp2020-tts/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"The Commission also recommends\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"As a result of these studies, the planning document submitted by the Secretary of the Treasury to the Bureau of the Budget on August thirty-one.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"The FBI now transmits information on all defectors, a category which would, of course, have included Oswald.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"they seem unduly restrictive in continuing to require some manifestation of animus against a Government official.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"and each agency given clear understanding of the assistance which the Secret Service expects.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Other examples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"Be a voice, not an echo.\" # 'echo' is not in training set. \n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"The human voice is the most perfect instrument of all.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"I'm sorry Dave. I'm afraid I can't do that.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"This cake is great. It's so delicious and moist.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Comparison with https://keithito.github.io/audio-samples/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"Generative adversarial network or variational auto-encoder.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"Scientists at the CERN laboratory say they have discovered a new particle.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"Here’s a way to measure the acute emotional intelligence that has never gone out of style.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"President Trump met with other leaders at the Group of 20 conference.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"The buses aren't the problem, they actually provide a solution.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Comparison with https://google.github.io/tacotron/publications/tacotron/index.html" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"Generative adversarial network or variational auto-encoder.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"Basilar membrane and otolaryngology are not auto-correlations.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \" He has read the whole thing.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"He reads books.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"Thisss isrealy awhsome.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"This is your internet browser, Firefox.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"This is your internet browser Firefox.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"The quick brown fox jumps over the lazy dog.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"Does the quick brown fox jump over the lazy dog?\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"Eren, how are you?\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Hard Sentences" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"Encouraged, he started with a minute a day.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"His meditation consisted of “body scanning” which involved focusing his mind and energy on each section of the body from head to toe .\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase the grey matter in the parts of the brain responsible for emotional regulation and learning . \"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"If he decided to watch TV he really watched it.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sentence = \"Often we try to bring about change through sheer effort and we put all of our energy into a new initiative .\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# for twb dataset\n", + "sentence = \"In our preparation for Easter, God in his providence offers us each year the season of Lent as a sacramental sign of our conversion.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/Benchmark.ipynb b/notebooks/Benchmark.ipynb index 00ac7d16..7d3a45cf 100644 --- a/notebooks/Benchmark.ipynb +++ b/notebooks/Benchmark.ipynb @@ -65,7 +65,7 @@ "from TTS.utils.text import text_to_sequence\n", "from TTS.utils.synthesis import synthesis\n", "from TTS.utils.visual import visualize\n", - "from TTS.utils.text.symbols import symbols, phonemes\n", + "from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n", "\n", "import IPython\n", "from IPython.display import Audio\n", @@ -149,6 +149,10 @@ " speakers = []\n", " speaker_id = None\n", "\n", + "# if the vocabulary was passed, replace the default\n", + "if 'characters' in CONFIG.keys():\n", + " symbols, phonemes = make_symbols(**CONFIG.characters)\n", + "\n", "# load the model\n", "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n", "model = setup_model(num_chars, len(speakers), CONFIG)\n", diff --git a/notebooks/ExtractTTSpectrogram.ipynb b/notebooks/ExtractTTSpectrogram.ipynb index 20038f78..b5a88611 100644 --- a/notebooks/ExtractTTSpectrogram.ipynb +++ b/notebooks/ExtractTTSpectrogram.ipynb @@ -37,7 +37,7 @@ "from TTS.utils.audio import AudioProcessor\n", "from TTS.utils.visual import plot_spectrogram\n", "from TTS.utils.generic_utils import load_config, setup_model, sequence_mask\n", - "from TTS.utils.text.symbols import symbols, phonemes\n", + "from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n", "\n", "%matplotlib inline\n", "\n", @@ -94,6 +94,10 @@ "metadata": {}, "outputs": [], "source": [ + "# if the vocabulary was passed, replace the default\n", + "if 'characters' in C.keys():\n", + " symbols, phonemes = make_symbols(**C.characters)\n", + "\n", "# load the model\n", "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n", "# TODO: multiple speaker\n", @@ -116,7 +120,7 @@ "preprocessor = importlib.import_module('TTS.datasets.preprocess')\n", "preprocessor = getattr(preprocessor, DATASET.lower())\n", "meta_data = preprocessor(DATA_PATH,METADATA_FILE)\n", - "dataset = MyDataset(checkpoint['r'], C.text_cleaner, ap, meta_data, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n", + "dataset = MyDataset(checkpoint['r'], C.text_cleaner, ap, meta_data,tp=C.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n", "loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)" ] }, diff --git a/notebooks/TestAttention.ipynb b/notebooks/TestAttention.ipynb index a1867d13..9d3e5e75 100644 --- a/notebooks/TestAttention.ipynb +++ b/notebooks/TestAttention.ipynb @@ -100,7 +100,7 @@ "outputs": [], "source": [ "# LOAD TTS MODEL\n", - "from TTS.utils.text.symbols import symbols, phonemes\n", + "from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n", "\n", "# multi speaker \n", "if CONFIG.use_speaker_embedding:\n", @@ -110,6 +110,10 @@ " speakers = []\n", " speaker_id = None\n", "\n", + "# if the vocabulary was passed, replace the default\n", + "if 'characters' in CONFIG.keys():\n", + " symbols, phonemes = make_symbols(**CONFIG.characters)\n", + "\n", "# load the model\n", "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n", "model = setup_model(num_chars, len(speakers), CONFIG)\n", diff --git a/server/README.md b/server/README.md index 95297225..0563ef94 100644 --- a/server/README.md +++ b/server/README.md @@ -6,6 +6,10 @@ Instructions below are based on a Ubuntu 18.04 machine, but it should be simple #### Development server: +##### Using server.py +If you have the environment set already for TTS, then you can directly call ```setup.py```. + +##### Using .whl 1. apt-get install -y espeak libsndfile1 python3-venv 2. python3 -m venv /tmp/venv 3. source /tmp/venv/bin/activate diff --git a/server/server.py b/server/server.py index d40e2427..705937e2 100644 --- a/server/server.py +++ b/server/server.py @@ -14,30 +14,52 @@ def create_argparser(): parser.add_argument('--tts_checkpoint', type=str, help='path to TTS checkpoint file') parser.add_argument('--tts_config', type=str, help='path to TTS config.json file') parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model') - parser.add_argument('--wavernn_lib_path', type=str, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.') - parser.add_argument('--wavernn_file', type=str, help='path to WaveRNN checkpoint file.') - parser.add_argument('--wavernn_config', type=str, help='path to WaveRNN config file.') + parser.add_argument('--wavernn_lib_path', type=str, default=None, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.') + parser.add_argument('--wavernn_file', type=str, default=None, help='path to WaveRNN checkpoint file.') + parser.add_argument('--wavernn_config', type=str, default=None, help='path to WaveRNN config file.') parser.add_argument('--is_wavernn_batched', type=convert_boolean, default=False, help='true to use batched WaveRNN.') + parser.add_argument('--pwgan_lib_path', type=str, default=None, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.') + parser.add_argument('--pwgan_file', type=str, default=None, help='path to ParallelWaveGAN checkpoint file.') + parser.add_argument('--pwgan_config', type=str, default=None, help='path to ParallelWaveGAN config file.') parser.add_argument('--port', type=int, default=5002, help='port to listen on.') parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.') parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.') return parser -config = None synthesizer = None -embedded_model_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model') -checkpoint_file = os.path.join(embedded_model_folder, 'checkpoint.pth.tar') -config_file = os.path.join(embedded_model_folder, 'config.json') +embedded_models_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model') -if os.path.isfile(checkpoint_file) and os.path.isfile(config_file): - # Use default config with embedded model files - config = create_argparser().parse_args([]) - config.tts_checkpoint = checkpoint_file - config.tts_config = config_file - synthesizer = Synthesizer(config) +embedded_tts_folder = os.path.join(embedded_models_folder, 'tts') +tts_checkpoint_file = os.path.join(embedded_tts_folder, 'checkpoint.pth.tar') +tts_config_file = os.path.join(embedded_tts_folder, 'config.json') +embedded_wavernn_folder = os.path.join(embedded_models_folder, 'wavernn') +wavernn_checkpoint_file = os.path.join(embedded_wavernn_folder, 'checkpoint.pth.tar') +wavernn_config_file = os.path.join(embedded_wavernn_folder, 'config.json') + +embedded_pwgan_folder = os.path.join(embedded_models_folder, 'pwgan') +pwgan_checkpoint_file = os.path.join(embedded_pwgan_folder, 'checkpoint.pkl') +pwgan_config_file = os.path.join(embedded_pwgan_folder, 'config.yml') + +args = create_argparser().parse_args() + +# If these were not specified in the CLI args, use default values with embedded model files +if not args.tts_checkpoint and os.path.isfile(tts_checkpoint_file): + args.tts_checkpoint = tts_checkpoint_file +if not args.tts_config and os.path.isfile(tts_config_file): + args.tts_config = tts_config_file +if not args.wavernn_file and os.path.isfile(wavernn_checkpoint_file): + args.wavernn_file = wavernn_checkpoint_file +if not args.wavernn_config and os.path.isfile(wavernn_config_file): + args.wavernn_config = wavernn_config_file +if not args.pwgan_file and os.path.isfile(pwgan_checkpoint_file): + args.pwgan_file = pwgan_checkpoint_file +if not args.pwgan_config and os.path.isfile(pwgan_config_file): + args.pwgan_config = pwgan_config_file + +synthesizer = Synthesizer(args) app = Flask(__name__) @@ -55,11 +77,4 @@ def tts(): if __name__ == '__main__': - args = create_argparser().parse_args() - - # Setup synthesizer from CLI args if they're specified or no embedded model - # is present. - if not config or not synthesizer or args.tts_checkpoint or args.tts_config: - synthesizer = Synthesizer(args) - - app.run(debug=config.debug, host='0.0.0.0', port=config.port) + app.run(debug=args.debug, host='0.0.0.0', port=args.port) diff --git a/server/synthesizer.py b/server/synthesizer.py index d8852a3e..e9205bf1 100644 --- a/server/synthesizer.py +++ b/server/synthesizer.py @@ -1,17 +1,20 @@ import io -import os +import re +import sys import numpy as np import torch -import sys +import yaml from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import load_config, setup_model -from TTS.utils.text import phonemes, symbols from TTS.utils.speakers import load_speaker_mapping +# pylint: disable=unused-wildcard-import +# pylint: disable=wildcard-import from TTS.utils.synthesis import * -import re +from TTS.utils.text import make_symbols, phonemes, symbols + alphabets = r"([A-Za-z])" prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]" suffixes = r"(Inc|Ltd|Jr|Sr|Co)" @@ -23,6 +26,7 @@ websites = r"[.](com|net|org|io|gov)" class Synthesizer(object): def __init__(self, config): self.wavernn = None + self.pwgan = None self.config = config self.use_cuda = self.config.use_cuda if self.use_cuda: @@ -30,28 +34,38 @@ class Synthesizer(object): self.load_tts(self.config.tts_checkpoint, self.config.tts_config, self.config.use_cuda) if self.config.wavernn_lib_path: - self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_path, - self.config.wavernn_file, self.config.wavernn_config, - self.config.use_cuda) + self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_file, + self.config.wavernn_config, self.config.use_cuda) + if self.config.pwgan_lib_path: + self.load_pwgan(self.config.pwgan_lib_path, self.config.pwgan_file, + self.config.pwgan_config, self.config.use_cuda) def load_tts(self, tts_checkpoint, tts_config, use_cuda): + # pylint: disable=global-statement + global symbols, phonemes + print(" > Loading TTS model ...") print(" | > model config: ", tts_config) print(" | > checkpoint file: ", tts_checkpoint) + self.tts_config = load_config(tts_config) self.use_phonemes = self.tts_config.use_phonemes self.ap = AudioProcessor(**self.tts_config.audio) + + if 'characters' in self.tts_config.keys(): + symbols, phonemes = make_symbols(**self.tts_config.characters) + if self.use_phonemes: self.input_size = len(phonemes) else: self.input_size = len(symbols) - # load speakers + # TODO: fix this for multi-speaker model - load speakers if self.config.tts_speakers is not None: - self.tts_speakers = load_speaker_mapping(os.path.join(model_path, self.config.tts_speakers)) + self.tts_speakers = load_speaker_mapping(self.config.tts_speakers) num_speakers = len(self.tts_speakers) else: num_speakers = 0 - self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config) + self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config) # load model state cp = torch.load(tts_checkpoint, map_location=torch.device('cpu')) # load the model @@ -63,16 +77,17 @@ class Synthesizer(object): if 'r' in cp: self.tts_model.decoder.set_r(cp['r']) - def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda): + def load_wavernn(self, lib_path, model_file, model_config, use_cuda): # TODO: set a function in wavernn code base for model setup and call it here. - sys.path.append(lib_path) # set this if TTS is not installed globally + sys.path.append(lib_path) # set this if WaveRNN is not installed globally + #pylint: disable=import-outside-toplevel from WaveRNN.models.wavernn import Model - wavernn_config = os.path.join(model_path, model_config) - model_file = os.path.join(model_path, model_file) print(" > Loading WaveRNN model ...") - print(" | > model config: ", wavernn_config) + print(" | > model config: ", model_config) print(" | > model file: ", model_file) - self.wavernn_config = load_config(wavernn_config) + self.wavernn_config = load_config(model_config) + # This is the default architecture we use for our models. + # You might need to update it self.wavernn = Model( rnn_dims=512, fc_dims=512, @@ -80,7 +95,7 @@ class Synthesizer(object): mulaw=self.wavernn_config.mulaw, pad=self.wavernn_config.pad, use_aux_net=self.wavernn_config.use_aux_net, - use_upsample_net = self.wavernn_config.use_upsample_net, + use_upsample_net=self.wavernn_config.use_upsample_net, upsample_factors=self.wavernn_config.upsample_factors, feat_dims=80, compute_dims=128, @@ -90,19 +105,36 @@ class Synthesizer(object): sample_rate=self.ap.sample_rate, ).cuda() - check = torch.load(model_file) + check = torch.load(model_file, map_location="cpu") self.wavernn.load_state_dict(check['model']) if use_cuda: self.wavernn.cuda() self.wavernn.eval() + def load_pwgan(self, lib_path, model_file, model_config, use_cuda): + sys.path.append(lib_path) # set this if ParallelWaveGAN is not installed globally + #pylint: disable=import-outside-toplevel + from parallel_wavegan.models import ParallelWaveGANGenerator + print(" > Loading PWGAN model ...") + print(" | > model config: ", model_config) + print(" | > model file: ", model_file) + with open(model_config) as f: + self.pwgan_config = yaml.load(f, Loader=yaml.Loader) + self.pwgan = ParallelWaveGANGenerator(**self.pwgan_config["generator_params"]) + self.pwgan.load_state_dict(torch.load(model_file, map_location="cpu")["model"]["generator"]) + self.pwgan.remove_weight_norm() + if use_cuda: + self.pwgan.cuda() + self.pwgan.eval() + def save_wav(self, wav, path): # wav *= 32767 / max(1e-8, np.max(np.abs(wav))) wav = np.array(wav) self.ap.save_wav(wav, path) - def split_into_sentences(self, text): - text = " " + text + " " + @staticmethod + def split_into_sentences(text): + text = " " + text + " " text = text.replace("\n", " ") text = re.sub(prefixes, "\\1", text) text = re.sub(websites, "\\1", text) @@ -129,15 +161,13 @@ class Synthesizer(object): text = text.replace("", ".") sentences = text.split("") sentences = sentences[:-1] - sentences = [s.strip() for s in sentences] + sentences = list(filter(None, [s.strip() for s in sentences])) # remove empty sentences return sentences def tts(self, text): wavs = [] sens = self.split_into_sentences(text) print(sens) - if not sens: - sens = [text+'.'] for sen in sens: # preprocess the given text inputs = text_to_seqvec(sen, self.tts_config, self.use_cuda) @@ -148,9 +178,16 @@ class Synthesizer(object): postnet_output, decoder_output, _ = parse_outputs( postnet_output, decoder_output, alignments) - if self.wavernn: - postnet_output = postnet_output[0].data.cpu().numpy() - wav = self.wavernn.generate(torch.FloatTensor(postnet_output.T).unsqueeze(0).cuda(), batched=self.config.is_wavernn_batched, target=11000, overlap=550) + if self.pwgan: + vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0) + if self.use_cuda: + vocoder_input.cuda() + wav = self.pwgan.inference(vocoder_input, hop_size=self.ap.hop_length) + elif self.wavernn: + vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0) + if self.use_cuda: + vocoder_input.cuda() + wav = self.wavernn.generate(vocoder_input, batched=self.config.is_wavernn_batched, target=11000, overlap=550) else: wav = inv_spectrogram(postnet_output, self.ap, self.tts_config) # trim silence diff --git a/setup.py b/setup.py index 63782800..f92dac8a 100644 --- a/setup.py +++ b/setup.py @@ -61,10 +61,11 @@ package_data = ['server/templates/*'] if 'bdist_wheel' in unknown_args and args.checkpoint and args.model_config: print('Embedding model in wheel file...') model_dir = os.path.join('server', 'model') - os.makedirs(model_dir, exist_ok=True) - embedded_checkpoint_path = os.path.join(model_dir, 'checkpoint.pth.tar') + tts_dir = os.path.join(model_dir, 'tts') + os.makedirs(tts_dir, exist_ok=True) + embedded_checkpoint_path = os.path.join(tts_dir, 'checkpoint.pth.tar') shutil.copy(args.checkpoint, embedded_checkpoint_path) - embedded_config_path = os.path.join(model_dir, 'config.json') + embedded_config_path = os.path.join(tts_dir, 'config.json') shutil.copy(args.model_config, embedded_config_path) package_data.extend([embedded_checkpoint_path, embedded_config_path]) diff --git a/synthesize.py b/synthesize.py index a338f8b8..1f1ce36f 100644 --- a/synthesize.py +++ b/synthesize.py @@ -1,3 +1,4 @@ +# pylint: disable=redefined-outer-name, unused-argument import os import time import argparse @@ -7,7 +8,7 @@ import string from TTS.utils.synthesis import synthesis from TTS.utils.generic_utils import load_config, setup_model -from TTS.utils.text.symbols import symbols, phonemes +from TTS.utils.text.symbols import make_symbols, symbols, phonemes from TTS.utils.audio import AudioProcessor @@ -47,6 +48,8 @@ def tts(model, if __name__ == "__main__": + global symbols, phonemes + parser = argparse.ArgumentParser() parser.add_argument('text', type=str, help='Text to generate speech.') parser.add_argument('config_path', @@ -104,6 +107,10 @@ if __name__ == "__main__": # load the audio processor ap = AudioProcessor(**C.audio) + # if the vocabulary was passed, replace the default + if 'characters' in C.keys(): + symbols, phonemes = make_symbols(**C.characters) + # load speakers if args.speakers_json != '': speakers = json.load(open(args.speakers_json, 'r')) diff --git a/tests/inputs/server_config.json b/tests/inputs/server_config.json index 3988db4c..7f5a60fb 100644 --- a/tests/inputs/server_config.json +++ b/tests/inputs/server_config.json @@ -3,9 +3,11 @@ "tts_config":"dummy_model_config.json", // tts config.json file "tts_speakers": null, // json file listing speaker ids. null if no speaker embedding. "wavernn_lib_path": null, // Rootpath to wavernn project folder to be imported. If this is null, model uses GL for speech synthesis. - "wavernn_path": null, // wavernn model root path "wavernn_file": null, // wavernn checkpoint file name "wavernn_config": null, // wavernn config file + "pwgan_lib_path": null, + "pwgan_file": null, + "pwgan_config": null, "is_wavernn_batched":true, "port": 5002, "use_cuda": false, diff --git a/tests/test_config.json b/tests/test_config.json index 0cd3d751..6d63e6ab 100644 --- a/tests/test_config.json +++ b/tests/test_config.json @@ -19,6 +19,16 @@ "mel_fmax": 7600, // maximum freq level for mel-spec. Tune for dataset!! "do_trim_silence": false }, + + "characters":{ + "pad": "_", + "eos": "~", + "bos": "^", + "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ", + "punctuations":"!'(),-.:;? ", + "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ" + }, + "hidden_size": 128, "embedding_size": 256, "text_cleaner": "english_cleaners", diff --git a/tests/test_demo_server.py b/tests/test_demo_server.py index c343a6a4..a0837686 100644 --- a/tests/test_demo_server.py +++ b/tests/test_demo_server.py @@ -5,13 +5,19 @@ import torch as T from TTS.server.synthesizer import Synthesizer from TTS.tests import get_tests_input_path, get_tests_output_path -from TTS.utils.text.symbols import phonemes, symbols +from TTS.utils.text.symbols import make_symbols, phonemes, symbols from TTS.utils.generic_utils import load_config, save_checkpoint, setup_model class DemoServerTest(unittest.TestCase): + # pylint: disable=R0201 def _create_random_model(self): + # pylint: disable=global-statement + global symbols, phonemes config = load_config(os.path.join(get_tests_output_path(), 'dummy_model_config.json')) + if 'characters' in config.keys(): + symbols, phonemes = make_symbols(**config.characters) + num_chars = len(phonemes) if config.use_phonemes else len(symbols) model = setup_model(num_chars, 0, config) output_path = os.path.join(get_tests_output_path()) diff --git a/tests/test_layers.py b/tests/test_layers.py index 7d02b673..d7c8829f 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -131,7 +131,7 @@ class L1LossMaskedTests(unittest.TestCase): dummy_target = T.zeros(4, 8, 128).float() dummy_length = (T.ones(4) * 8).long() output = layer(dummy_input, dummy_target, dummy_length) - assert output.item() == 1.0, "1.0 vs {}".format(output.data[0]) + assert output.item() == 1.0, "1.0 vs {}".format(output.item()) # test if padded values of input makes any difference dummy_input = T.ones(4, 8, 128).float() @@ -140,7 +140,7 @@ class L1LossMaskedTests(unittest.TestCase): mask = ( (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) output = layer(dummy_input + mask, dummy_target, dummy_length) - assert output.item() == 1.0, "1.0 vs {}".format(output.data[0]) + assert output.item() == 1.0, "1.0 vs {}".format(output.item()) dummy_input = T.rand(4, 8, 128).float() dummy_target = dummy_input.detach() @@ -148,4 +148,37 @@ class L1LossMaskedTests(unittest.TestCase): mask = ( (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) output = layer(dummy_input + mask, dummy_target, dummy_length) - assert output.item() == 0, "0 vs {}".format(output.data[0]) + assert output.item() == 0, "0 vs {}".format(output.item()) + + # seq_len_norm = True + # test input == target + layer = L1LossMasked(seq_len_norm=True) + dummy_input = T.ones(4, 8, 128).float() + dummy_target = T.ones(4, 8, 128).float() + dummy_length = (T.ones(4) * 8).long() + output = layer(dummy_input, dummy_target, dummy_length) + assert output.item() == 0.0 + + # test input != target + dummy_input = T.ones(4, 8, 128).float() + dummy_target = T.zeros(4, 8, 128).float() + dummy_length = (T.ones(4) * 8).long() + output = layer(dummy_input, dummy_target, dummy_length) + assert output.item() == 1.0, "1.0 vs {}".format(output.item()) + + # test if padded values of input makes any difference + dummy_input = T.ones(4, 8, 128).float() + dummy_target = T.zeros(4, 8, 128).float() + dummy_length = (T.arange(5, 9)).long() + mask = ( + (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) + output = layer(dummy_input + mask, dummy_target, dummy_length) + assert abs(output.item() - 1.0) < 1e-5, "1.0 vs {}".format(output.item()) + + dummy_input = T.rand(4, 8, 128).float() + dummy_target = dummy_input.detach() + dummy_length = (T.arange(5, 9)).long() + mask = ( + (sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2) + output = layer(dummy_input + mask, dummy_target, dummy_length) + assert output.item() == 0, "0 vs {}".format(output.item()) diff --git a/tests/test_loader.py b/tests/test_loader.py index 751bc181..d835c5d3 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -37,7 +37,8 @@ class TestTTSDataset(unittest.TestCase): r, c.text_cleaner, ap=self.ap, - meta_data=items, + meta_data=items, + tp=c.characters if 'characters' in c.keys() else None, batch_group_size=bgs, min_seq_len=c.min_seq_len, max_seq_len=float("inf"), @@ -137,9 +138,7 @@ class TestTTSDataset(unittest.TestCase): # NOTE: Below needs to check == 0 but due to an unknown reason # there is a slight difference between two matrices. # TODO: Check this assert cond more in detail. - assert abs((abs(mel.T) - - abs(mel_dl) - ).sum()) < 1e-5, (abs(mel.T) - abs(mel_dl)).sum() + assert abs(mel.T - mel_dl).max() < 1e-5, abs(mel.T - mel_dl).max() # check mel-spec correctness mel_spec = mel_input[0].cpu().numpy() diff --git a/tests/test_server_package.sh b/tests/test_server_package.sh index 01e42843..9fe5e8b1 100755 --- a/tests/test_server_package.sh +++ b/tests/test_server_package.sh @@ -11,7 +11,7 @@ source /tmp/venv/bin/activate pip install --quiet --upgrade pip setuptools wheel rm -f dist/*.whl -python setup.py bdist_wheel --checkpoint tests/outputs/checkpoint_10.pth.tar --model_config tests/outputs/dummy_model_config.json +python setup.py --quiet bdist_wheel --checkpoint tests/outputs/checkpoint_10.pth.tar --model_config tests/outputs/dummy_model_config.json pip install --quiet dist/TTS*.whl python -m TTS.server.server & diff --git a/tests/test_text_processing.py b/tests/test_text_processing.py index 8f8e6fab..6c0c7058 100644 --- a/tests/test_text_processing.py +++ b/tests/test_text_processing.py @@ -1,7 +1,14 @@ +import os +# pylint: disable=unused-wildcard-import +# pylint: disable=wildcard-import +# pylint: disable=unused-import import unittest -import torch as T - from TTS.utils.text import * +from TTS.tests import get_tests_path +from TTS.utils.generic_utils import load_config + +TESTS_PATH = get_tests_path() +conf = load_config(os.path.join(TESTS_PATH, 'test_config.json')) def test_phoneme_to_sequence(): text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!" @@ -9,67 +16,80 @@ def test_phoneme_to_sequence(): lang = "en-us" sequence = phoneme_to_sequence(text, text_cleaner, lang) text_hat = sequence_to_phoneme(sequence) + sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) gt = "ɹiːsənt ɹɪsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!" - assert text_hat == gt + assert text_hat == text_hat_with_params == gt # multiple punctuations text = "Be a voice, not an! echo?" sequence = phoneme_to_sequence(text, text_cleaner, lang) text_hat = sequence_to_phoneme(sequence) + sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?" print(text_hat) print(len(sequence)) - assert text_hat == gt + assert text_hat == text_hat_with_params == gt # not ending with punctuation text = "Be a voice, not an! echo" sequence = phoneme_to_sequence(text, text_cleaner, lang) text_hat = sequence_to_phoneme(sequence) + sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ" print(text_hat) print(len(sequence)) - assert text_hat == gt + assert text_hat == text_hat_with_params == gt # original text = "Be a voice, not an echo!" sequence = phoneme_to_sequence(text, text_cleaner, lang) text_hat = sequence_to_phoneme(sequence) + sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!" print(text_hat) print(len(sequence)) - assert text_hat == gt + assert text_hat == text_hat_with_params == gt # extra space after the sentence text = "Be a voice, not an! echo. " sequence = phoneme_to_sequence(text, text_cleaner, lang) text_hat = sequence_to_phoneme(sequence) + sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ." print(text_hat) print(len(sequence)) - assert text_hat == gt + assert text_hat == text_hat_with_params == gt # extra space after the sentence text = "Be a voice, not an! echo. " sequence = phoneme_to_sequence(text, text_cleaner, lang, True) text_hat = sequence_to_phoneme(sequence) + sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~" print(text_hat) print(len(sequence)) - assert text_hat == gt + assert text_hat == text_hat_with_params == gt # padding char text = "_Be a _voice, not an! echo_" sequence = phoneme_to_sequence(text, text_cleaner, lang) text_hat = sequence_to_phoneme(sequence) + sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ" print(text_hat) print(len(sequence)) - assert text_hat == gt - + assert text_hat == text_hat_with_params == gt def test_text2phone(): text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!" - gt = "ɹ|iː|s|ə|n|t| |ɹ|ɪ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i|| |ɪ|n|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ɪ|s|p|ɑː|n|s|ə|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|uː|l|eɪ|ʃ|ə|n||| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!" + gt = "ɹ|iː|s|ə|n|t| |ɹ|ɪ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|n|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ɪ|s|p|ɑː|n|s|ə|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|uː|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!" lang = "en-us" - phonemes = text2phone(text, lang) - assert gt == phonemes + ph = text2phone(text, lang) + assert gt == ph, f"\n{phonemes} \n vs \n{gt}" \ No newline at end of file diff --git a/train.py b/train.py index 81bc2c72..0aa3f748 100644 --- a/train.py +++ b/train.py @@ -13,19 +13,19 @@ from torch.utils.data import DataLoader from TTS.datasets.TTSDataset import MyDataset from distribute import (DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor) -from TTS.layers.losses import L1LossMasked, MSELossMasked +from TTS.layers.losses import L1LossMasked, MSELossMasked, BCELossMasked from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import ( NoamLR, check_update, count_parameters, create_experiment_folder, get_git_branch, load_config, remove_experiment_folder, save_best_model, save_checkpoint, adam_weight_decay, set_init_dict, copy_config_file, setup_model, gradual_training_scheduler, KeepAverage, - set_weight_decay) + set_weight_decay, check_config) from TTS.utils.logger import Logger from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \ get_speakers from TTS.utils.synthesis import synthesis -from TTS.utils.text.symbols import phonemes, symbols +from TTS.utils.text.symbols import make_symbols, phonemes, symbols from TTS.utils.visual import plot_alignment, plot_spectrogram from TTS.datasets.preprocess import load_meta_data from TTS.utils.radam import RAdam @@ -49,6 +49,7 @@ def setup_loader(ap, r, is_val=False, verbose=False): c.text_cleaner, meta_data=meta_data_eval if is_val else meta_data_train, ap=ap, + tp=c.characters if 'characters' in c.keys() else None, batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, min_seq_len=c.min_seq_len, @@ -167,7 +168,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, # loss computation stop_loss = criterion_st(stop_tokens, - stop_targets) if c.stopnet else torch.zeros(1) + stop_targets, mel_lengths) if c.stopnet else torch.zeros(1) if c.loss_masking: decoder_loss = criterion(decoder_output, mel_input, mel_lengths) if c.model in ["Tacotron", "TacotronGST"]: @@ -327,6 +328,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler, return keep_avg['avg_postnet_loss'], global_step +@torch.no_grad() def evaluate(model, criterion, criterion_st, ap, global_step, epoch): data_loader = setup_loader(ap, model.decoder.r, is_val=True) if c.use_speaker_embedding: @@ -346,125 +348,124 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch): keep_avg.add_values(eval_values_dict) print("\n > Validation") - with torch.no_grad(): - if data_loader is not None: - for num_iter, data in enumerate(data_loader): - start_time = time.time() + if data_loader is not None: + for num_iter, data in enumerate(data_loader): + start_time = time.time() - # format data - text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(data) - assert mel_input.shape[1] % model.decoder.r == 0 + # format data + text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(data) + assert mel_input.shape[1] % model.decoder.r == 0 - # forward pass model - if c.bidirectional_decoder: - decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model( - text_input, text_lengths, mel_input, speaker_ids=speaker_ids) - else: - decoder_output, postnet_output, alignments, stop_tokens = model( - text_input, text_lengths, mel_input, speaker_ids=speaker_ids) + # forward pass model + if c.bidirectional_decoder: + decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model( + text_input, text_lengths, mel_input, speaker_ids=speaker_ids) + else: + decoder_output, postnet_output, alignments, stop_tokens = model( + text_input, text_lengths, mel_input, speaker_ids=speaker_ids) - # loss computation - stop_loss = criterion_st( - stop_tokens, stop_targets) if c.stopnet else torch.zeros(1) - if c.loss_masking: - decoder_loss = criterion(decoder_output, mel_input, - mel_lengths) - if c.model in ["Tacotron", "TacotronGST"]: - postnet_loss = criterion(postnet_output, linear_input, - mel_lengths) - else: - postnet_loss = criterion(postnet_output, mel_input, - mel_lengths) - else: - decoder_loss = criterion(decoder_output, mel_input) - if c.model in ["Tacotron", "TacotronGST"]: - postnet_loss = criterion(postnet_output, linear_input) - else: - postnet_loss = criterion(postnet_output, mel_input) - loss = decoder_loss + postnet_loss + stop_loss - - # backward decoder loss - if c.bidirectional_decoder: - if c.loss_masking: - decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input, mel_lengths) - else: - decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input) - decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_backward_output, dims=(1, )), decoder_output) - loss += decoder_backward_loss + decoder_c_loss - keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()}) - - step_time = time.time() - start_time - epoch_time += step_time - - # compute alignment score - align_score = alignment_diagonal_score(alignments) - keep_avg.update_value('avg_align_score', align_score) - - # aggregate losses from processes - if num_gpus > 1: - postnet_loss = reduce_tensor(postnet_loss.data, num_gpus) - decoder_loss = reduce_tensor(decoder_loss.data, num_gpus) - if c.stopnet: - stop_loss = reduce_tensor(stop_loss.data, num_gpus) - - keep_avg.update_values({ - 'avg_postnet_loss': - float(postnet_loss.item()), - 'avg_decoder_loss': - float(decoder_loss.item()), - 'avg_stop_loss': - float(stop_loss.item()), - }) - - if num_iter % c.print_step == 0: - print( - " | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} " - "StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}" - .format(loss.item(), postnet_loss.item(), - keep_avg['avg_postnet_loss'], - decoder_loss.item(), - keep_avg['avg_decoder_loss'], stop_loss.item(), - keep_avg['avg_stop_loss'], align_score, - keep_avg['avg_align_score']), - flush=True) - - if args.rank == 0: - # Diagnostic visualizations - idx = np.random.randint(mel_input.shape[0]) - const_spec = postnet_output[idx].data.cpu().numpy() - gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [ - "Tacotron", "TacotronGST" - ] else mel_input[idx].data.cpu().numpy() - align_img = alignments[idx].data.cpu().numpy() - - eval_figures = { - "prediction": plot_spectrogram(const_spec, ap), - "ground_truth": plot_spectrogram(gt_spec, ap), - "alignment": plot_alignment(align_img) - } - - # Sample audio + # loss computation + stop_loss = criterion_st( + stop_tokens, stop_targets, mel_lengths) if c.stopnet else torch.zeros(1) + if c.loss_masking: + decoder_loss = criterion(decoder_output, mel_input, + mel_lengths) if c.model in ["Tacotron", "TacotronGST"]: - eval_audio = ap.inv_spectrogram(const_spec.T) + postnet_loss = criterion(postnet_output, linear_input, + mel_lengths) else: - eval_audio = ap.inv_mel_spectrogram(const_spec.T) - tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, - c.audio["sample_rate"]) + postnet_loss = criterion(postnet_output, mel_input, + mel_lengths) + else: + decoder_loss = criterion(decoder_output, mel_input) + if c.model in ["Tacotron", "TacotronGST"]: + postnet_loss = criterion(postnet_output, linear_input) + else: + postnet_loss = criterion(postnet_output, mel_input) + loss = decoder_loss + postnet_loss + stop_loss - # Plot Validation Stats - epoch_stats = { - "loss_postnet": keep_avg['avg_postnet_loss'], - "loss_decoder": keep_avg['avg_decoder_loss'], - "stop_loss": keep_avg['avg_stop_loss'], - "alignment_score": keep_avg['avg_align_score'] - } + # backward decoder loss + if c.bidirectional_decoder: + if c.loss_masking: + decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input, mel_lengths) + else: + decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input) + decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_backward_output, dims=(1, )), decoder_output) + loss += decoder_backward_loss + decoder_c_loss + keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()}) - if c.bidirectional_decoder: - epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_b_loss'] - align_b_img = alignments_backward[idx].data.cpu().numpy() - eval_figures['alignment_backward'] = plot_alignment(align_b_img) - tb_logger.tb_eval_stats(global_step, epoch_stats) - tb_logger.tb_eval_figures(global_step, eval_figures) + step_time = time.time() - start_time + epoch_time += step_time + + # compute alignment score + align_score = alignment_diagonal_score(alignments) + keep_avg.update_value('avg_align_score', align_score) + + # aggregate losses from processes + if num_gpus > 1: + postnet_loss = reduce_tensor(postnet_loss.data, num_gpus) + decoder_loss = reduce_tensor(decoder_loss.data, num_gpus) + if c.stopnet: + stop_loss = reduce_tensor(stop_loss.data, num_gpus) + + keep_avg.update_values({ + 'avg_postnet_loss': + float(postnet_loss.item()), + 'avg_decoder_loss': + float(decoder_loss.item()), + 'avg_stop_loss': + float(stop_loss.item()), + }) + + if num_iter % c.print_step == 0: + print( + " | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} " + "StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}" + .format(loss.item(), postnet_loss.item(), + keep_avg['avg_postnet_loss'], + decoder_loss.item(), + keep_avg['avg_decoder_loss'], stop_loss.item(), + keep_avg['avg_stop_loss'], align_score, + keep_avg['avg_align_score']), + flush=True) + + if args.rank == 0: + # Diagnostic visualizations + idx = np.random.randint(mel_input.shape[0]) + const_spec = postnet_output[idx].data.cpu().numpy() + gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [ + "Tacotron", "TacotronGST" + ] else mel_input[idx].data.cpu().numpy() + align_img = alignments[idx].data.cpu().numpy() + + eval_figures = { + "prediction": plot_spectrogram(const_spec, ap), + "ground_truth": plot_spectrogram(gt_spec, ap), + "alignment": plot_alignment(align_img) + } + + # Sample audio + if c.model in ["Tacotron", "TacotronGST"]: + eval_audio = ap.inv_spectrogram(const_spec.T) + else: + eval_audio = ap.inv_mel_spectrogram(const_spec.T) + tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, + c.audio["sample_rate"]) + + # Plot Validation Stats + epoch_stats = { + "loss_postnet": keep_avg['avg_postnet_loss'], + "loss_decoder": keep_avg['avg_decoder_loss'], + "stop_loss": keep_avg['avg_stop_loss'], + "alignment_score": keep_avg['avg_align_score'] + } + + if c.bidirectional_decoder: + epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_b_loss'] + align_b_img = alignments_backward[idx].data.cpu().numpy() + eval_figures['alignment_backward'] = plot_alignment(align_b_img) + tb_logger.tb_eval_stats(global_step, epoch_stats) + tb_logger.tb_eval_figures(global_step, eval_figures) if args.rank == 0 and epoch > c.test_delay_epochs: if c.test_sentences_file is None: @@ -493,7 +494,12 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch): use_cuda, ap, speaker_id=speaker_id, - style_wav=style_wav) + style_wav=style_wav, + truncated=False, + enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument + use_griffin_lim=True, + do_trim_silence=False) + file_path = os.path.join(AUDIO_PATH, str(global_step)) os.makedirs(file_path, exist_ok=True) file_path = os.path.join(file_path, @@ -515,9 +521,12 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch): # FIXME: move args definition/parsing inside of main? def main(args): # pylint: disable=redefined-outer-name - global meta_data_train, meta_data_eval + # pylint: disable=global-variable-undefined + global meta_data_train, meta_data_eval, symbols, phonemes # Audio processor ap = AudioProcessor(**c.audio) + if 'characters' in c.keys(): + symbols, phonemes = make_symbols(**c.characters) # DISTRUBUTED if num_gpus > 1: @@ -561,12 +570,12 @@ def main(args): # pylint: disable=redefined-outer-name optimizer_st = None if c.loss_masking: - criterion = L1LossMasked() if c.model in ["Tacotron", "TacotronGST" - ] else MSELossMasked() + criterion = L1LossMasked(c.seq_len_norm) if c.model in ["Tacotron", "TacotronGST" + ] else MSELossMasked(c.seq_len_norm) else: criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST" ] else nn.MSELoss() - criterion_st = nn.BCEWithLogitsLoss( + criterion_st = BCELossMasked( pos_weight=torch.tensor(10)) if c.stopnet else None if args.restore_path: @@ -687,6 +696,7 @@ if __name__ == '__main__': # setup output paths and read configs c = load_config(args.config_path) + check_config(c) _ = os.path.dirname(os.path.realpath(__file__)) OUT_PATH = args.continue_path diff --git a/utils/audio.py b/utils/audio.py index 708f0853..771e6a43 100644 --- a/utils/audio.py +++ b/utils/audio.py @@ -12,6 +12,8 @@ class AudioProcessor(object): min_level_db=None, frame_shift_ms=None, frame_length_ms=None, + hop_length=None, + win_length=None, ref_level_db=None, num_freq=None, power=None, @@ -24,6 +26,7 @@ class AudioProcessor(object): clip_norm=True, griffin_lim_iters=None, do_trim_silence=False, + trim_db=60, sound_norm=False, **_): @@ -46,8 +49,14 @@ class AudioProcessor(object): self.max_norm = 1.0 if max_norm is None else float(max_norm) self.clip_norm = clip_norm self.do_trim_silence = do_trim_silence + self.trim_db = trim_db self.sound_norm = sound_norm - self.n_fft, self.hop_length, self.win_length = self._stft_parameters() + if hop_length is None: + self.n_fft, self.hop_length, self.win_length = self._stft_parameters() + else: + self.hop_length = hop_length + self.win_length = win_length + self.n_fft = (self.num_freq - 1) * 2 assert min_level_db != 0.0, " [!] min_level_db is 0" members = vars(self) for key, value in members.items(): @@ -66,12 +75,11 @@ class AudioProcessor(object): return np.maximum(1e-10, np.dot(inv_mel_basis, mel_spec)) def _build_mel_basis(self, ): - n_fft = (self.num_freq - 1) * 2 if self.mel_fmax is not None: assert self.mel_fmax <= self.sample_rate // 2 return librosa.filters.mel( self.sample_rate, - n_fft, + self.n_fft, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax) @@ -197,6 +205,7 @@ class AudioProcessor(object): n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, + pad_mode='constant' ) def _istft(self, y): @@ -217,7 +226,7 @@ class AudioProcessor(object): margin = int(self.sample_rate * 0.01) wav = wav[margin:-margin] return librosa.effects.trim( - wav, top_db=60, frame_length=self.win_length, hop_length=self.hop_length)[0] + wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[0] @staticmethod def mulaw_encode(wav, qc): diff --git a/utils/data.py b/utils/data.py index 87343ec1..f2d7538a 100644 --- a/utils/data.py +++ b/utils/data.py @@ -14,7 +14,7 @@ def prepare_data(inputs): def _pad_tensor(x, length): - _pad = 0 + _pad = 0. assert x.ndim == 2 x = np.pad( x, [[0, 0], [0, length - x.shape[1]]], @@ -31,7 +31,7 @@ def prepare_tensor(inputs, out_steps): def _pad_stop_target(x, length): - _pad = 1. + _pad = 0. assert x.ndim == 1 return np.pad( x, (0, length - x.shape[0]), mode='constant', constant_values=_pad) diff --git a/utils/generic_utils.py b/utils/generic_utils.py index cf1a83a6..f6c38530 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -389,3 +389,133 @@ class KeepAverage(): def update_values(self, value_dict): for key, value in value_dict.items(): self.update_value(key, value) + + +def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None): + if alternative in c.keys() and c[alternative] is not None: + return + if restricted: + assert name in c.keys(), f' [!] {name} not defined in config.json' + if name in c.keys(): + if max_val: + assert c[name] <= max_val, f' [!] {name} is larger than max value {max_val}' + if min_val: + assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}' + if enum_list: + assert c[name].lower() in enum_list, f' [!] {name} is not a valid value' + if val_type: + assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}' + + +def check_config(c): + _check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str) + _check_argument('run_name', c, restricted=True, val_type=str) + _check_argument('run_description', c, val_type=str) + + # AUDIO + _check_argument('audio', c, restricted=True, val_type=dict) + + # audio processing parameters + _check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056) + _check_argument('num_freq', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058) + _check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000) + _check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length') + _check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length') + _check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1) + _check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10) + _check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000) + _check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5) + _check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000) + + # vocabulary parameters + _check_argument('characters', c, restricted=False, val_type=dict) + _check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + _check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + _check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + _check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + _check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + _check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + + # normalization parameters + _check_argument('signal_norm', c['audio'], restricted=True, val_type=bool) + _check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool) + _check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000) + _check_argument('clip_norm', c['audio'], restricted=True, val_type=bool) + _check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000) + _check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0) + _check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool) + _check_argument('trim_db', c['audio'], restricted=True, val_type=int) + + # training parameters + _check_argument('batch_size', c, restricted=True, val_type=int, min_val=1) + _check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1) + _check_argument('r', c, restricted=True, val_type=int, min_val=1) + _check_argument('gradual_training', c, restricted=False, val_type=list) + _check_argument('loss_masking', c, restricted=True, val_type=bool) + # _check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100) + + # validation parameters + _check_argument('run_eval', c, restricted=True, val_type=bool) + _check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0) + _check_argument('test_sentences_file', c, restricted=False, val_type=str) + + # optimizer + _check_argument('noam_schedule', c, restricted=False, val_type=bool) + _check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0) + _check_argument('epochs', c, restricted=True, val_type=int, min_val=1) + _check_argument('lr', c, restricted=True, val_type=float, min_val=0) + _check_argument('wd', c, restricted=True, val_type=float, min_val=0) + _check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0) + _check_argument('seq_len_norm', c, restricted=True, val_type=bool) + + # tacotron prenet + _check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1) + _check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn']) + _check_argument('prenet_dropout', c, restricted=True, val_type=bool) + + # attention + _check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original']) + _check_argument('attention_heads', c, restricted=True, val_type=int) + _check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax']) + _check_argument('windowing', c, restricted=True, val_type=bool) + _check_argument('use_forward_attn', c, restricted=True, val_type=bool) + _check_argument('forward_attn_mask', c, restricted=True, val_type=bool) + _check_argument('transition_agent', c, restricted=True, val_type=bool) + _check_argument('transition_agent', c, restricted=True, val_type=bool) + _check_argument('location_attn', c, restricted=True, val_type=bool) + _check_argument('bidirectional_decoder', c, restricted=True, val_type=bool) + + # stopnet + _check_argument('stopnet', c, restricted=True, val_type=bool) + _check_argument('separate_stopnet', c, restricted=True, val_type=bool) + + # tensorboard + _check_argument('print_step', c, restricted=True, val_type=int, min_val=1) + _check_argument('save_step', c, restricted=True, val_type=int, min_val=1) + _check_argument('checkpoint', c, restricted=True, val_type=bool) + _check_argument('tb_model_param_stats', c, restricted=True, val_type=bool) + + # dataloading + _check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=['english_cleaners', 'phoneme_cleaners', 'transliteration_cleaners', 'basic_cleaners']) + _check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool) + _check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0) + _check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0) + _check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0) + _check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0) + _check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10) + + # paths + _check_argument('output_path', c, restricted=True, val_type=str) + + # multi-speaker gst + _check_argument('use_speaker_embedding', c, restricted=True, val_type=bool) + _check_argument('style_wav_for_test', c, restricted=True, val_type=str) + _check_argument('use_gst', c, restricted=True, val_type=bool) + + # datasets - checking only the first entry + _check_argument('datasets', c, restricted=True, val_type=list) + for dataset_entry in c['datasets']: + _check_argument('name', dataset_entry, restricted=True, val_type=str) + _check_argument('path', dataset_entry, restricted=True, val_type=str) + _check_argument('meta_file_train', dataset_entry, restricted=True, val_type=str) + _check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str) \ No newline at end of file diff --git a/utils/synthesis.py b/utils/synthesis.py index 79a17c78..1047c16b 100644 --- a/utils/synthesis.py +++ b/utils/synthesis.py @@ -9,10 +9,11 @@ def text_to_seqvec(text, CONFIG, use_cuda): if CONFIG.use_phonemes: seq = np.asarray( phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language, - CONFIG.enable_eos_bos_chars), + CONFIG.enable_eos_bos_chars, + tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), dtype=np.int32) else: - seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32) + seq = np.asarray(text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), dtype=np.int32) # torch tensor chars_var = torch.from_numpy(seq).unsqueeze(0) if use_cuda: @@ -69,6 +70,24 @@ def id_to_torch(speaker_id): return speaker_id +# TODO: perform GL with pytorch for batching +def apply_griffin_lim(inputs, input_lens, CONFIG, ap): + '''Apply griffin-lim to each sample iterating throught the first dimension. + Args: + inputs (Tensor or np.Array): Features to be converted by GL. First dimension is the batch size. + input_lens (Tensor or np.Array): 1D array of sample lengths. + CONFIG (Dict): TTS config. + ap (AudioProcessor): TTS audio processor. + ''' + wavs = [] + for idx, spec in enumerate(inputs): + wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length # inverse librosa padding + wav = inv_spectrogram(spec, ap, CONFIG) + # assert len(wav) == wav_len, f" [!] wav lenght: {len(wav)} vs expected: {wav_len}" + wavs.append(wav[:wav_len]) + return wavs + + def synthesis(model, text, CONFIG, diff --git a/utils/text/__init__.py b/utils/text/__init__.py index 1c5b98c3..79069192 100644 --- a/utils/text/__init__.py +++ b/utils/text/__init__.py @@ -1,18 +1,19 @@ # -*- coding: utf-8 -*- import re +from packaging import version import phonemizer from phonemizer.phonemize import phonemize from TTS.utils.text import cleaners -from TTS.utils.text.symbols import symbols, phonemes, _phoneme_punctuations, _bos, \ +from TTS.utils.text.symbols import make_symbols, symbols, phonemes, _phoneme_punctuations, _bos, \ _eos # Mappings from symbol to numeric ID and vice versa: -_SYMBOL_TO_ID = {s: i for i, s in enumerate(symbols)} -_ID_TO_SYMBOL = {i: s for i, s in enumerate(symbols)} +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} -_PHONEMES_TO_ID = {s: i for i, s in enumerate(phonemes)} -_ID_TO_PHONEMES = {i: s for i, s in enumerate(phonemes)} +_phonemes_to_id = {s: i for i, s in enumerate(phonemes)} +_id_to_phonemes = {i: s for i, s in enumerate(phonemes)} # Regular expression matching text enclosed in curly braces: _CURLY_RE = re.compile(r'(.*?)\{(.+?)\}(.*)') @@ -28,29 +29,53 @@ def text2phone(text, language): seperator = phonemizer.separator.Separator(' |', '', '|') #try: punctuations = re.findall(PHONEME_PUNCTUATION_PATTERN, text) - ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language) - ph = ph[:-1].strip() # skip the last empty character - # Replace \n with matching punctuations. - if punctuations: - # if text ends with a punctuation. - if text[-1] == punctuations[-1]: - for punct in punctuations[:-1]: - ph = ph.replace('| |\n', '|'+punct+'| |', 1) - try: - ph = ph + punctuations[-1] - except: - print(text) - else: - for punct in punctuations: - ph = ph.replace('| |\n', '|'+punct+'| |', 1) + if version.parse(phonemizer.__version__) < version.parse('2.1'): + ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language) + ph = ph[:-1].strip() # skip the last empty character + # phonemizer does not tackle punctuations. Here we do. + # Replace \n with matching punctuations. + if punctuations: + # if text ends with a punctuation. + if text[-1] == punctuations[-1]: + for punct in punctuations[:-1]: + ph = ph.replace('| |\n', '|'+punct+'| |', 1) + ph = ph + punctuations[-1] + else: + for punct in punctuations: + ph = ph.replace('| |\n', '|'+punct+'| |', 1) + elif version.parse(phonemizer.__version__) >= version.parse('2.1'): + ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language, preserve_punctuation=True) + # this is a simple fix for phonemizer. + # https://github.com/bootphon/phonemizer/issues/32 + if punctuations: + for punctuation in punctuations: + ph = ph.replace(f"| |{punctuation} ", f"|{punctuation}| |").replace(f"| |{punctuation}", f"|{punctuation}| |") + ph = ph[:-3] + else: + raise RuntimeError(" [!] Use 'phonemizer' version 2.1 or older.") + return ph -def pad_with_eos_bos(phoneme_sequence): - return [_PHONEMES_TO_ID[_bos]] + list(phoneme_sequence) + [_PHONEMES_TO_ID[_eos]] +def pad_with_eos_bos(phoneme_sequence, tp=None): + # pylint: disable=global-statement + global _phonemes_to_id, _bos, _eos + if tp: + _bos = tp['bos'] + _eos = tp['eos'] + _, _phonemes = make_symbols(**tp) + _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)} + + return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]] -def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False): +def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None): + # pylint: disable=global-statement + global _phonemes_to_id + if tp: + _, _phonemes = make_symbols(**tp) + _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)} + sequence = [] text = text.replace(":", "") clean_text = _clean_text(text, cleaner_names) @@ -62,21 +87,27 @@ def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False): sequence += _phoneme_to_sequence(phoneme) # Append EOS char if enable_eos_bos: - sequence = pad_with_eos_bos(sequence) + sequence = pad_with_eos_bos(sequence, tp=tp) return sequence -def sequence_to_phoneme(sequence): +def sequence_to_phoneme(sequence, tp=None): + # pylint: disable=global-statement '''Converts a sequence of IDs back to a string''' + global _id_to_phonemes result = '' + if tp: + _, _phonemes = make_symbols(**tp) + _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)} + for symbol_id in sequence: - if symbol_id in _ID_TO_PHONEMES: - s = _ID_TO_PHONEMES[symbol_id] + if symbol_id in _id_to_phonemes: + s = _id_to_phonemes[symbol_id] result += s return result.replace('}{', ' ') -def text_to_sequence(text, cleaner_names): +def text_to_sequence(text, cleaner_names, tp=None): '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. The text can optionally have ARPAbet sequences enclosed in curly braces embedded @@ -89,6 +120,12 @@ def text_to_sequence(text, cleaner_names): Returns: List of integers corresponding to the symbols in the text ''' + # pylint: disable=global-statement + global _symbol_to_id + if tp: + _symbols, _ = make_symbols(**tp) + _symbol_to_id = {s: i for i, s in enumerate(_symbols)} + sequence = [] # Check for curly braces and treat their contents as ARPAbet: while text: @@ -103,12 +140,18 @@ def text_to_sequence(text, cleaner_names): return sequence -def sequence_to_text(sequence): +def sequence_to_text(sequence, tp=None): '''Converts a sequence of IDs back to a string''' + # pylint: disable=global-statement + global _id_to_symbol + if tp: + _symbols, _ = make_symbols(**tp) + _id_to_symbol = {i: s for i, s in enumerate(_symbols)} + result = '' for symbol_id in sequence: - if symbol_id in _ID_TO_SYMBOL: - s = _ID_TO_SYMBOL[symbol_id] + if symbol_id in _id_to_symbol: + s = _id_to_symbol[symbol_id] # Enclose ARPAbet back in curly braces: if len(s) > 1 and s[0] == '@': s = '{%s}' % s[1:] @@ -126,11 +169,11 @@ def _clean_text(text, cleaner_names): def _symbols_to_sequence(syms): - return [_SYMBOL_TO_ID[s] for s in syms if _should_keep_symbol(s)] + return [_symbol_to_id[s] for s in syms if _should_keep_symbol(s)] def _phoneme_to_sequence(phons): - return [_PHONEMES_TO_ID[s] for s in list(phons) if _should_keep_phoneme(s)] + return [_phonemes_to_id[s] for s in list(phons) if _should_keep_phoneme(s)] def _arpabet_to_sequence(text): @@ -138,8 +181,8 @@ def _arpabet_to_sequence(text): def _should_keep_symbol(s): - return s in _SYMBOL_TO_ID and s not in ['~', '^', '_'] + return s in _symbol_to_id and s not in ['~', '^', '_'] def _should_keep_phoneme(p): - return p in _PHONEMES_TO_ID and p not in ['~', '^', '_'] + return p in _phonemes_to_id and p not in ['~', '^', '_'] diff --git a/utils/text/cleaners.py b/utils/text/cleaners.py index 581633a2..e6b611b4 100644 --- a/utils/text/cleaners.py +++ b/utils/text/cleaners.py @@ -63,6 +63,19 @@ def convert_to_ascii(text): return unidecode(text) +def remove_aux_symbols(text): + text = re.sub(r'[\<\>\(\)\[\]\"]+', '', text) + return text + + +def replace_symbols(text): + text = text.replace(';', ',') + text = text.replace('-', ' ') + text = text.replace(':', ' ') + text = text.replace('&', 'and') + return text + + def basic_cleaners(text): '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' text = lowercase(text) @@ -84,6 +97,8 @@ def english_cleaners(text): text = lowercase(text) text = expand_numbers(text) text = expand_abbreviations(text) + text = replace_symbols(text) + text = remove_aux_symbols(text) text = collapse_whitespace(text) return text @@ -93,5 +108,7 @@ def phoneme_cleaners(text): text = convert_to_ascii(text) text = expand_numbers(text) text = expand_abbreviations(text) + text = replace_symbols(text) + text = remove_aux_symbols(text) text = collapse_whitespace(text) return text diff --git a/utils/text/symbols.py b/utils/text/symbols.py index ee6fd2cf..544277c5 100644 --- a/utils/text/symbols.py +++ b/utils/text/symbols.py @@ -5,6 +5,18 @@ Defines the set of symbols used in text input to the model. The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' +def make_symbols(characters, phonemes, punctuations='!\'(),-.:;? ', pad='_', eos='~', bos='^'):# pylint: disable=redefined-outer-name + ''' Function to create symbols and phonemes ''' + _phonemes_sorted = sorted(list(phonemes)) + + # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): + _arpabet = ['@' + s for s in _phonemes_sorted] + + # Export all symbols: + _symbols = [pad, eos, bos] + list(characters) + _arpabet + _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations) + + return _symbols, _phonemes _pad = '_' _eos = '~' @@ -20,14 +32,9 @@ _pulmonic_consonants = 'pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðsz _suprasegmentals = 'ˈˌːˑ' _other_symbols = 'ʍwɥʜʢʡɕʑɺɧ' _diacrilics = 'ɚ˞ɫ' -_phonemes = sorted(list(_vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics)) +_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics -# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): -_arpabet = ['@' + s for s in _phonemes] - -# Export all symbols: -symbols = [_pad, _eos, _bos] + list(_characters) + _arpabet -phonemes = [_pad, _eos, _bos] + list(_phonemes) + list(_punctuations) +symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _eos, _bos) # Generate ALIEN language # from random import shuffle diff --git a/utils/visual.py b/utils/visual.py index ab513666..1cb9ac5d 100644 --- a/utils/visual.py +++ b/utils/visual.py @@ -54,9 +54,10 @@ def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CON plt.xlabel("Decoder timestamp", fontsize=label_fontsize) plt.ylabel("Encoder timestamp", fontsize=label_fontsize) if CONFIG.use_phonemes: - seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars) - text = sequence_to_phoneme(seq) + seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None) + text = sequence_to_phoneme(seq, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None) print(text) + plt.yticks(range(len(text)), list(text)) plt.colorbar()