diff --git a/.github/stale.yml b/.github/stale.yml new file mode 100644 index 00000000..5bac63d3 --- /dev/null +++ b/.github/stale.yml @@ -0,0 +1,19 @@ +# Number of days of inactivity before an issue becomes stale +daysUntilStale: 60 +# Number of days of inactivity before a stale issue is closed +daysUntilClose: 7 +# Issues with these labels will never be considered stale +exemptLabels: + - pinned + - security +# Label to use when marking an issue as stale +staleLabel: wontfix +# Comment to post when marking an issue as stale. Set to `false` to disable +markComment: > + This issue has been automatically marked as stale because it has not had + recent activity. It will be closed if no further activity occurs. Thank you + for your contributions. You might also look our discourse page for further help. + https://discourse.mozilla.org/c/tts +# Comment to post when closing a stale issue. Set to `false` to disable +closeComment: false + diff --git a/README.md b/README.md index e98be3c4..19d7fa24 100644 --- a/README.md +++ b/README.md @@ -115,10 +115,7 @@ In case of any error or intercepted execution, if there is no checkpoint yet und You can also enjoy Tensorboard, if you point Tensorboard argument```--logdir``` to the experiment folder. -## Testing -Best way to test your network is to use Notebooks under ```notebooks``` folder. - -There is also a good [CoLab](https://colab.research.google.com/github/tugstugi/dl-colab-notebooks/blob/master/notebooks/Mozilla_TTS_WaveRNN.ipynb) sample using pre-trained models (by @tugstugi). +## [Testing and Examples](https://github.com/mozilla/TTS/wiki/Examples-using-TTS) ## Contribution guidelines This repository is governed by Mozilla's code of conduct and etiquette guidelines. For more details, please read the [Mozilla Community Participation Guidelines.](https://www.mozilla.org/about/governance/policies/participation/) @@ -139,28 +136,7 @@ If you like to use TTS to try a new idea and like to share your experiments with - Share your results as you proceed. (Tensorboard log files, audio results, visuals etc.) - Use LJSpeech dataset (for English) if you like to compare results with the released models. (It is the most open scalable dataset for quick experimentation) -## Contact/Getting Help -- [Wiki](https://github.com/mozilla/TTS/wiki) - -- [Discourse Forums](https://discourse.mozilla.org/c/tts) - If your question is not addressed in the Wiki, the Discourse Forums is the next place to look. They contain conversations on General Topics, Using TTS, and TTS Development. - -- [Issues](https://github.com/mozilla/TTS/issues) - Finally, if all else fails, you can open an issue in our repo. - - +## [Contact/Getting Help](https://github.com/mozilla/TTS/wiki/Contact-and-Getting-Help) ## Major TODOs - [x] Implement the model. diff --git a/config.json b/config.json index 62cccac8..171dd5c6 100644 --- a/config.json +++ b/config.json @@ -11,6 +11,8 @@ "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. "win_length": 1024, // stft window length in ms. "hop_length": 256, // stft window hop-lengh in ms. + "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. + "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. "min_level_db": -100, // normalization range "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. @@ -19,7 +21,7 @@ // Normalization parameters "signal_norm": true, // normalize the spec values in range [0, 1] "symmetric_norm": true, // move normalization to range [-1, 1] - "max_norm": 4, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] "clip_norm": true, // clip normalized values into the range. "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! @@ -27,6 +29,18 @@ "trim_db": 60 // threshold for timming silence. Set this according to your dataset. }, + // VOCABULARY PARAMETERS + // if custom character set is not defined, + // default set in symbols.py is used + "characters":{ + "pad": "_", + "eos": "~", + "bos": "^", + "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ", + "punctuations":"!'(),-.:;? ", + "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ" + }, + // DISTRIBUTED TRAINING "distributed":{ "backend": "nccl", @@ -36,11 +50,12 @@ "reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers. // TRAINING - "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. + "batch_size": 2, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. "eval_batch_size":16, "r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled. "gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceeed. "loss_masking": true, // enable / disable loss masking against the sequence padding. + "grad_accum": 2, // if N > 1, enable gradient accumulation for N iterations. It is useful for low memory GPUs. // VALIDATION "run_eval": true, @@ -49,7 +64,7 @@ // OPTIMIZER "noam_schedule": false, // use noam warmup and lr schedule. - "grad_clip": 1, // upper limit for gradients for clipping. + "grad_clip": 1.0, // upper limit for gradients for clipping. "epochs": 1000, // total number of epochs to train. "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. "wd": 0.000001, // Weight decay weight. diff --git a/datasets/TTSDataset.py b/datasets/TTSDataset.py index a45d77ff..d3a6f486 100644 --- a/datasets/TTSDataset.py +++ b/datasets/TTSDataset.py @@ -15,6 +15,7 @@ class MyDataset(Dataset): text_cleaner, ap, meta_data, + tp=None, batch_group_size=0, min_seq_len=0, max_seq_len=float("inf"), @@ -49,6 +50,7 @@ class MyDataset(Dataset): self.min_seq_len = min_seq_len self.max_seq_len = max_seq_len self.ap = ap + self.tp = tp self.use_phonemes = use_phonemes self.phoneme_cache_path = phoneme_cache_path self.phoneme_language = phoneme_language @@ -75,13 +77,13 @@ class MyDataset(Dataset): def _generate_and_cache_phoneme_sequence(self, text, cache_path): """generate a phoneme sequence from text. - since the usage is for subsequent caching, we never add bos and eos chars here. Instead we add those dynamically later; based on the config option.""" phonemes = phoneme_to_sequence(text, [self.cleaners], language=self.phoneme_language, - enable_eos_bos=False) + enable_eos_bos=False, + tp=self.tp) phonemes = np.asarray(phonemes, dtype=np.int32) np.save(cache_path, phonemes) return phonemes @@ -101,7 +103,7 @@ class MyDataset(Dataset): phonemes = self._generate_and_cache_phoneme_sequence(text, cache_path) if self.enable_eos_bos: - phonemes = pad_with_eos_bos(phonemes) + phonemes = pad_with_eos_bos(phonemes, tp=self.tp) phonemes = np.asarray(phonemes, dtype=np.int32) return phonemes @@ -113,7 +115,7 @@ class MyDataset(Dataset): text = self._load_or_generate_phoneme_sequence(wav_file, text) else: text = np.asarray( - text_to_sequence(text, [self.cleaners]), dtype=np.int32) + text_to_sequence(text, [self.cleaners], tp=self.tp), dtype=np.int32) assert text.size > 0, self.items[idx][1] assert wav.size > 0, self.items[idx][1] @@ -193,7 +195,7 @@ class MyDataset(Dataset): mel = [self.ap.melspectrogram(w).astype('float32') for w in wav] linear = [self.ap.spectrogram(w).astype('float32') for w in wav] - mel_lengths = [m.shape[1] for m in mel] + mel_lengths = [m.shape[1] for m in mel] # compute 'stop token' targets stop_targets = [ diff --git a/datasets/preprocess.py b/datasets/preprocess.py index a78abab9..64efc665 100644 --- a/datasets/preprocess.py +++ b/datasets/preprocess.py @@ -60,22 +60,6 @@ def tweb(root_path, meta_file): # return {'text': texts, 'wavs': wavs} -def mozilla_old(root_path, meta_file): - """Normalizes Mozilla meta data files to TTS format""" - txt_file = os.path.join(root_path, meta_file) - items = [] - speaker_name = "mozilla_old" - with open(txt_file, 'r') as ttf: - for line in ttf: - cols = line.split('|') - batch_no = int(cols[1].strip().split("_")[0]) - wav_folder = "batch{}".format(batch_no) - wav_file = os.path.join(root_path, wav_folder, "wavs_no_processing", cols[1].strip()) - text = cols[0].strip() - items.append([text, wav_file, speaker_name]) - return items - - def mozilla(root_path, meta_file): """Normalizes Mozilla meta data files to TTS format""" txt_file = os.path.join(root_path, meta_file) @@ -91,6 +75,22 @@ def mozilla(root_path, meta_file): return items +def mozilla_de(root_path, meta_file): + """Normalizes Mozilla meta data files to TTS format""" + txt_file = os.path.join(root_path, meta_file) + items = [] + speaker_name = "mozilla" + with open(txt_file, 'r', encoding="ISO 8859-1") as ttf: + for line in ttf: + cols = line.strip().split('|') + wav_file = cols[0].strip() + text = cols[1].strip() + folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL" + wav_file = os.path.join(root_path, folder_name, wav_file) + items.append([text, wav_file, speaker_name]) + return items + + def mailabs(root_path, meta_files=None): """Normalizes M-AI-Labs meta data files to TTS format""" speaker_regex = re.compile("by_book/(male|female)/(?P[^/]+)/") diff --git a/layers/tacotron2.py b/layers/tacotron2.py index c195b277..fa76a6b2 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -64,7 +64,6 @@ class Encoder(nn.Module): def forward(self, x, input_lengths): x = self.convolutions(x) x = x.transpose(1, 2) - input_lengths = input_lengths.cpu().numpy() x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True) diff --git a/notebooks/Benchmark-PWGAN.ipynb b/notebooks/Benchmark-PWGAN.ipynb index 430d329f..840da10e 100644 --- a/notebooks/Benchmark-PWGAN.ipynb +++ b/notebooks/Benchmark-PWGAN.ipynb @@ -132,7 +132,7 @@ "outputs": [], "source": [ "# LOAD TTS MODEL\n", - "from TTS.utils.text.symbols import symbols, phonemes\n", + "from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n", "\n", "# multi speaker \n", "if CONFIG.use_speaker_embedding:\n", @@ -142,6 +142,10 @@ " speakers = []\n", " speaker_id = None\n", "\n", + "# if the vocabulary was passed, replace the default\n", + "if 'characters' in CONFIG.keys():\n", + " symbols, phonemes = make_symbols(**CONFIG.characters)\n", + "\n", "# load the model\n", "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n", "model = setup_model(num_chars, len(speakers), CONFIG)\n", diff --git a/notebooks/Benchmark.ipynb b/notebooks/Benchmark.ipynb index 7c528506..7d3a45cf 100644 --- a/notebooks/Benchmark.ipynb +++ b/notebooks/Benchmark.ipynb @@ -65,6 +65,7 @@ "from TTS.utils.text import text_to_sequence\n", "from TTS.utils.synthesis import synthesis\n", "from TTS.utils.visual import visualize\n", + "from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n", "\n", "import IPython\n", "from IPython.display import Audio\n", @@ -81,13 +82,15 @@ "source": [ "def tts(model, text, CONFIG, use_cuda, ap, use_gl, figures=True):\n", " t_1 = time.time()\n", - " waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, False, CONFIG.enable_eos_bos_chars)\n", + " waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, style_wav=None, \n", + " truncated=False, enable_eos_bos_chars=CONFIG.enable_eos_bos_chars,\n", + " use_griffin_lim=use_gl)\n", " if CONFIG.model == \"Tacotron\" and not use_gl:\n", " # coorect the normalization differences b/w TTS and the Vocoder.\n", " mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n", - " mel_postnet_spec = ap._denormalize(mel_postnet_spec)\n", - " mel_postnet_spec = ap_vocoder._normalize(mel_postnet_spec)\n", " if not use_gl:\n", + " mel_postnet_spec = ap._denormalize(mel_postnet_spec)\n", + " mel_postnet_spec = ap_vocoder._normalize(mel_postnet_spec)\n", " waveform = wavernn.generate(torch.FloatTensor(mel_postnet_spec.T).unsqueeze(0).cuda(), batched=batched_wavernn, target=8000, overlap=400)\n", "\n", " print(\" > Run-time: {}\".format(time.time() - t_1))\n", @@ -108,7 +111,7 @@ "outputs": [], "source": [ "# Set constants\n", - "ROOT_PATH = '/media/erogol/data_ssd/Models/libri_tts/5099/'\n", + "ROOT_PATH = '/home/erogol/Models/LJSpeech/ljspeech-bn-December-23-2019_08+34AM-ffea133/'\n", "MODEL_PATH = ROOT_PATH + '/best_model.pth.tar'\n", "CONFIG_PATH = ROOT_PATH + '/config.json'\n", "OUT_FOLDER = '/home/erogol/Dropbox/AudioSamples/benchmark_samples/'\n", @@ -116,7 +119,7 @@ "VOCODER_MODEL_PATH = \"/media/erogol/data_ssd/Models/wavernn/ljspeech/mold_ljspeech_best_model/checkpoint_433000.pth.tar\"\n", "VOCODER_CONFIG_PATH = \"/media/erogol/data_ssd/Models/wavernn/ljspeech/mold_ljspeech_best_model/config.json\"\n", "VOCODER_CONFIG = load_config(VOCODER_CONFIG_PATH)\n", - "use_cuda = False\n", + "use_cuda = True\n", "\n", "# Set some config fields manually for testing\n", "# CONFIG.windowing = False\n", @@ -127,7 +130,7 @@ "# CONFIG.stopnet = True\n", "\n", "# Set the vocoder\n", - "use_gl = False # use GL if True\n", + "use_gl = True # use GL if True\n", "batched_wavernn = True # use batched wavernn inference if True" ] }, @@ -138,8 +141,6 @@ "outputs": [], "source": [ "# LOAD TTS MODEL\n", - "from utils.text.symbols import symbols, phonemes\n", - "\n", "# multi speaker \n", "if CONFIG.use_speaker_embedding:\n", " speakers = json.load(open(f\"{ROOT_PATH}/speakers.json\", 'r'))\n", @@ -148,6 +149,10 @@ " speakers = []\n", " speaker_id = None\n", "\n", + "# if the vocabulary was passed, replace the default\n", + "if 'characters' in CONFIG.keys():\n", + " symbols, phonemes = make_symbols(**CONFIG.characters)\n", + "\n", "# load the model\n", "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n", "model = setup_model(num_chars, len(speakers), CONFIG)\n", @@ -181,7 +186,7 @@ "metadata": {}, "outputs": [], "source": [ - "# LOAD WAVERNN\n", + "# LOAD WAVERNN - Make sure you downloaded the model and installed the module\n", "if use_gl == False:\n", " from WaveRNN.models.wavernn import Model\n", " from WaveRNN.utils.audio import AudioProcessor as AudioProcessorVocoder\n", @@ -533,7 +538,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.3" + "version": "3.7.4" } }, "nbformat": 4, diff --git a/notebooks/ExtractTTSpectrogram.ipynb b/notebooks/ExtractTTSpectrogram.ipynb index 20038f78..b5a88611 100644 --- a/notebooks/ExtractTTSpectrogram.ipynb +++ b/notebooks/ExtractTTSpectrogram.ipynb @@ -37,7 +37,7 @@ "from TTS.utils.audio import AudioProcessor\n", "from TTS.utils.visual import plot_spectrogram\n", "from TTS.utils.generic_utils import load_config, setup_model, sequence_mask\n", - "from TTS.utils.text.symbols import symbols, phonemes\n", + "from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n", "\n", "%matplotlib inline\n", "\n", @@ -94,6 +94,10 @@ "metadata": {}, "outputs": [], "source": [ + "# if the vocabulary was passed, replace the default\n", + "if 'characters' in C.keys():\n", + " symbols, phonemes = make_symbols(**C.characters)\n", + "\n", "# load the model\n", "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n", "# TODO: multiple speaker\n", @@ -116,7 +120,7 @@ "preprocessor = importlib.import_module('TTS.datasets.preprocess')\n", "preprocessor = getattr(preprocessor, DATASET.lower())\n", "meta_data = preprocessor(DATA_PATH,METADATA_FILE)\n", - "dataset = MyDataset(checkpoint['r'], C.text_cleaner, ap, meta_data, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n", + "dataset = MyDataset(checkpoint['r'], C.text_cleaner, ap, meta_data,tp=C.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n", "loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)" ] }, diff --git a/notebooks/TestAttention.ipynb b/notebooks/TestAttention.ipynb index a1867d13..9d3e5e75 100644 --- a/notebooks/TestAttention.ipynb +++ b/notebooks/TestAttention.ipynb @@ -100,7 +100,7 @@ "outputs": [], "source": [ "# LOAD TTS MODEL\n", - "from TTS.utils.text.symbols import symbols, phonemes\n", + "from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n", "\n", "# multi speaker \n", "if CONFIG.use_speaker_embedding:\n", @@ -110,6 +110,10 @@ " speakers = []\n", " speaker_id = None\n", "\n", + "# if the vocabulary was passed, replace the default\n", + "if 'characters' in CONFIG.keys():\n", + " symbols, phonemes = make_symbols(**CONFIG.characters)\n", + "\n", "# load the model\n", "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n", "model = setup_model(num_chars, len(speakers), CONFIG)\n", diff --git a/server/README.md b/server/README.md index 95297225..0563ef94 100644 --- a/server/README.md +++ b/server/README.md @@ -6,6 +6,10 @@ Instructions below are based on a Ubuntu 18.04 machine, but it should be simple #### Development server: +##### Using server.py +If you have the environment set already for TTS, then you can directly call ```setup.py```. + +##### Using .whl 1. apt-get install -y espeak libsndfile1 python3-venv 2. python3 -m venv /tmp/venv 3. source /tmp/venv/bin/activate diff --git a/server/server.py b/server/server.py index 3be66f9e..705937e2 100644 --- a/server/server.py +++ b/server/server.py @@ -14,10 +14,13 @@ def create_argparser(): parser.add_argument('--tts_checkpoint', type=str, help='path to TTS checkpoint file') parser.add_argument('--tts_config', type=str, help='path to TTS config.json file') parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model') - parser.add_argument('--wavernn_lib_path', type=str, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.') - parser.add_argument('--wavernn_file', type=str, help='path to WaveRNN checkpoint file.') - parser.add_argument('--wavernn_config', type=str, help='path to WaveRNN config file.') + parser.add_argument('--wavernn_lib_path', type=str, default=None, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.') + parser.add_argument('--wavernn_file', type=str, default=None, help='path to WaveRNN checkpoint file.') + parser.add_argument('--wavernn_config', type=str, default=None, help='path to WaveRNN config file.') parser.add_argument('--is_wavernn_batched', type=convert_boolean, default=False, help='true to use batched WaveRNN.') + parser.add_argument('--pwgan_lib_path', type=str, default=None, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.') + parser.add_argument('--pwgan_file', type=str, default=None, help='path to ParallelWaveGAN checkpoint file.') + parser.add_argument('--pwgan_config', type=str, default=None, help='path to ParallelWaveGAN config file.') parser.add_argument('--port', type=int, default=5002, help='port to listen on.') parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.') parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.') @@ -26,28 +29,35 @@ def create_argparser(): synthesizer = None -embedded_model_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model') -checkpoint_file = os.path.join(embedded_model_folder, 'checkpoint.pth.tar') -config_file = os.path.join(embedded_model_folder, 'config.json') +embedded_models_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model') -# Default options with embedded model files -if os.path.isfile(checkpoint_file): - default_tts_checkpoint = checkpoint_file -else: - default_tts_checkpoint = None +embedded_tts_folder = os.path.join(embedded_models_folder, 'tts') +tts_checkpoint_file = os.path.join(embedded_tts_folder, 'checkpoint.pth.tar') +tts_config_file = os.path.join(embedded_tts_folder, 'config.json') -if os.path.isfile(config_file): - default_tts_config = config_file -else: - default_tts_config = None +embedded_wavernn_folder = os.path.join(embedded_models_folder, 'wavernn') +wavernn_checkpoint_file = os.path.join(embedded_wavernn_folder, 'checkpoint.pth.tar') +wavernn_config_file = os.path.join(embedded_wavernn_folder, 'config.json') + +embedded_pwgan_folder = os.path.join(embedded_models_folder, 'pwgan') +pwgan_checkpoint_file = os.path.join(embedded_pwgan_folder, 'checkpoint.pkl') +pwgan_config_file = os.path.join(embedded_pwgan_folder, 'config.yml') args = create_argparser().parse_args() -# If these were not specified in the CLI args, use default values -if not args.tts_checkpoint: - args.tts_checkpoint = default_tts_checkpoint -if not args.tts_config: - args.tts_config = default_tts_config +# If these were not specified in the CLI args, use default values with embedded model files +if not args.tts_checkpoint and os.path.isfile(tts_checkpoint_file): + args.tts_checkpoint = tts_checkpoint_file +if not args.tts_config and os.path.isfile(tts_config_file): + args.tts_config = tts_config_file +if not args.wavernn_file and os.path.isfile(wavernn_checkpoint_file): + args.wavernn_file = wavernn_checkpoint_file +if not args.wavernn_config and os.path.isfile(wavernn_config_file): + args.wavernn_config = wavernn_config_file +if not args.pwgan_file and os.path.isfile(pwgan_checkpoint_file): + args.pwgan_file = pwgan_checkpoint_file +if not args.pwgan_config and os.path.isfile(pwgan_config_file): + args.pwgan_config = pwgan_config_file synthesizer = Synthesizer(args) diff --git a/server/synthesizer.py b/server/synthesizer.py index d8852a3e..f73b73fc 100644 --- a/server/synthesizer.py +++ b/server/synthesizer.py @@ -1,17 +1,20 @@ import io -import os +import re +import sys import numpy as np import torch -import sys +import yaml from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import load_config, setup_model -from TTS.utils.text import phonemes, symbols from TTS.utils.speakers import load_speaker_mapping +# pylint: disable=unused-wildcard-import +# pylint: disable=wildcard-import from TTS.utils.synthesis import * -import re +from TTS.utils.text import make_symbols, phonemes, symbols + alphabets = r"([A-Za-z])" prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]" suffixes = r"(Inc|Ltd|Jr|Sr|Co)" @@ -23,6 +26,7 @@ websites = r"[.](com|net|org|io|gov)" class Synthesizer(object): def __init__(self, config): self.wavernn = None + self.pwgan = None self.config = config self.use_cuda = self.config.use_cuda if self.use_cuda: @@ -30,28 +34,38 @@ class Synthesizer(object): self.load_tts(self.config.tts_checkpoint, self.config.tts_config, self.config.use_cuda) if self.config.wavernn_lib_path: - self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_path, - self.config.wavernn_file, self.config.wavernn_config, - self.config.use_cuda) + self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_file, + self.config.wavernn_config, self.config.use_cuda) + if self.config.pwgan_lib_path: + self.load_pwgan(self.config.pwgan_lib_path, self.config.pwgan_file, + self.config.pwgan_config, self.config.use_cuda) def load_tts(self, tts_checkpoint, tts_config, use_cuda): + # pylint: disable=global-statement + global symbols, phonemes + print(" > Loading TTS model ...") print(" | > model config: ", tts_config) print(" | > checkpoint file: ", tts_checkpoint) + self.tts_config = load_config(tts_config) self.use_phonemes = self.tts_config.use_phonemes self.ap = AudioProcessor(**self.tts_config.audio) + + if 'characters' in self.tts_config.keys(): + symbols, phonemes = make_symbols(**self.tts_config.characters) + if self.use_phonemes: self.input_size = len(phonemes) else: self.input_size = len(symbols) - # load speakers + # TODO: fix this for multi-speaker model - load speakers if self.config.tts_speakers is not None: - self.tts_speakers = load_speaker_mapping(os.path.join(model_path, self.config.tts_speakers)) + self.tts_speakers = load_speaker_mapping(self.config.tts_speakers) num_speakers = len(self.tts_speakers) else: num_speakers = 0 - self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config) + self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config) # load model state cp = torch.load(tts_checkpoint, map_location=torch.device('cpu')) # load the model @@ -63,16 +77,17 @@ class Synthesizer(object): if 'r' in cp: self.tts_model.decoder.set_r(cp['r']) - def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda): + def load_wavernn(self, lib_path, model_file, model_config, use_cuda): # TODO: set a function in wavernn code base for model setup and call it here. - sys.path.append(lib_path) # set this if TTS is not installed globally + sys.path.append(lib_path) # set this if WaveRNN is not installed globally + #pylint: disable=import-outside-toplevel from WaveRNN.models.wavernn import Model - wavernn_config = os.path.join(model_path, model_config) - model_file = os.path.join(model_path, model_file) print(" > Loading WaveRNN model ...") - print(" | > model config: ", wavernn_config) + print(" | > model config: ", model_config) print(" | > model file: ", model_file) - self.wavernn_config = load_config(wavernn_config) + self.wavernn_config = load_config(model_config) + # This is the default architecture we use for our models. + # You might need to update it self.wavernn = Model( rnn_dims=512, fc_dims=512, @@ -80,7 +95,7 @@ class Synthesizer(object): mulaw=self.wavernn_config.mulaw, pad=self.wavernn_config.pad, use_aux_net=self.wavernn_config.use_aux_net, - use_upsample_net = self.wavernn_config.use_upsample_net, + use_upsample_net=self.wavernn_config.use_upsample_net, upsample_factors=self.wavernn_config.upsample_factors, feat_dims=80, compute_dims=128, @@ -91,18 +106,35 @@ class Synthesizer(object): ).cuda() check = torch.load(model_file) - self.wavernn.load_state_dict(check['model']) + self.wavernn.load_state_dict(check['model'], map_location="cpu") if use_cuda: self.wavernn.cuda() self.wavernn.eval() + def load_pwgan(self, lib_path, model_file, model_config, use_cuda): + sys.path.append(lib_path) # set this if ParallelWaveGAN is not installed globally + #pylint: disable=import-outside-toplevel + from parallel_wavegan.models import ParallelWaveGANGenerator + print(" > Loading PWGAN model ...") + print(" | > model config: ", model_config) + print(" | > model file: ", model_file) + with open(model_config) as f: + self.pwgan_config = yaml.load(f, Loader=yaml.Loader) + self.pwgan = ParallelWaveGANGenerator(**self.pwgan_config["generator_params"]) + self.pwgan.load_state_dict(torch.load(model_file, map_location="cpu")["model"]["generator"]) + self.pwgan.remove_weight_norm() + if use_cuda: + self.pwgan.cuda() + self.pwgan.eval() + def save_wav(self, wav, path): # wav *= 32767 / max(1e-8, np.max(np.abs(wav))) wav = np.array(wav) self.ap.save_wav(wav, path) - def split_into_sentences(self, text): - text = " " + text + " " + @staticmethod + def split_into_sentences(text): + text = " " + text + " " text = text.replace("\n", " ") text = re.sub(prefixes, "\\1", text) text = re.sub(websites, "\\1", text) @@ -129,15 +161,13 @@ class Synthesizer(object): text = text.replace("", ".") sentences = text.split("") sentences = sentences[:-1] - sentences = [s.strip() for s in sentences] + sentences = list(filter(None, [s.strip() for s in sentences])) # remove empty sentences return sentences def tts(self, text): wavs = [] sens = self.split_into_sentences(text) print(sens) - if not sens: - sens = [text+'.'] for sen in sens: # preprocess the given text inputs = text_to_seqvec(sen, self.tts_config, self.use_cuda) @@ -148,9 +178,16 @@ class Synthesizer(object): postnet_output, decoder_output, _ = parse_outputs( postnet_output, decoder_output, alignments) - if self.wavernn: - postnet_output = postnet_output[0].data.cpu().numpy() - wav = self.wavernn.generate(torch.FloatTensor(postnet_output.T).unsqueeze(0).cuda(), batched=self.config.is_wavernn_batched, target=11000, overlap=550) + if self.pwgan: + vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0) + if self.use_cuda: + vocoder_input.cuda() + wav = self.pwgan.inference(vocoder_input, hop_size=self.ap.hop_length) + elif self.wavernn: + vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0) + if self.use_cuda: + vocoder_input.cuda() + wav = self.wavernn.generate(vocoder_input, batched=self.config.is_wavernn_batched, target=11000, overlap=550) else: wav = inv_spectrogram(postnet_output, self.ap, self.tts_config) # trim silence diff --git a/setup.py b/setup.py index 63782800..f92dac8a 100644 --- a/setup.py +++ b/setup.py @@ -61,10 +61,11 @@ package_data = ['server/templates/*'] if 'bdist_wheel' in unknown_args and args.checkpoint and args.model_config: print('Embedding model in wheel file...') model_dir = os.path.join('server', 'model') - os.makedirs(model_dir, exist_ok=True) - embedded_checkpoint_path = os.path.join(model_dir, 'checkpoint.pth.tar') + tts_dir = os.path.join(model_dir, 'tts') + os.makedirs(tts_dir, exist_ok=True) + embedded_checkpoint_path = os.path.join(tts_dir, 'checkpoint.pth.tar') shutil.copy(args.checkpoint, embedded_checkpoint_path) - embedded_config_path = os.path.join(model_dir, 'config.json') + embedded_config_path = os.path.join(tts_dir, 'config.json') shutil.copy(args.model_config, embedded_config_path) package_data.extend([embedded_checkpoint_path, embedded_config_path]) diff --git a/synthesize.py b/synthesize.py index cb0ee8af..1f1ce36f 100644 --- a/synthesize.py +++ b/synthesize.py @@ -1,3 +1,4 @@ +# pylint: disable=redefined-outer-name, unused-argument import os import time import argparse @@ -7,7 +8,7 @@ import string from TTS.utils.synthesis import synthesis from TTS.utils.generic_utils import load_config, setup_model -from TTS.utils.text.symbols import symbols, phonemes +from TTS.utils.text.symbols import make_symbols, symbols, phonemes from TTS.utils.audio import AudioProcessor @@ -25,14 +26,16 @@ def tts(model, t_1 = time.time() use_vocoder_model = vocoder_model is not None waveform, alignment, _, postnet_output, stop_tokens = synthesis( - model, text, C, use_cuda, ap, speaker_id, False, - C.enable_eos_bos_chars) + model, text, C, use_cuda, ap, speaker_id, style_wav=False, + truncated=False, enable_eos_bos_chars=C.enable_eos_bos_chars, + use_griffin_lim=(not use_vocoder_model), do_trim_silence=True) + if C.model == "Tacotron" and use_vocoder_model: postnet_output = ap.out_linear_to_mel(postnet_output.T).T # correct if there is a scale difference b/w two models - postnet_output = ap._denormalize(postnet_output) - postnet_output = ap_vocoder._normalize(postnet_output) if use_vocoder_model: + postnet_output = ap._denormalize(postnet_output) + postnet_output = ap_vocoder._normalize(postnet_output) vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0) waveform = vocoder_model.generate( vocoder_input.cuda() if use_cuda else vocoder_input, @@ -45,6 +48,8 @@ def tts(model, if __name__ == "__main__": + global symbols, phonemes + parser = argparse.ArgumentParser() parser.add_argument('text', type=str, help='Text to generate speech.') parser.add_argument('config_path', @@ -58,7 +63,7 @@ if __name__ == "__main__": parser.add_argument( 'out_path', type=str, - help='Path to save final wav file.', + help='Path to save final wav file. Wav file will be names as the text given.', ) parser.add_argument('--use_cuda', type=bool, @@ -102,6 +107,10 @@ if __name__ == "__main__": # load the audio processor ap = AudioProcessor(**C.audio) + # if the vocabulary was passed, replace the default + if 'characters' in C.keys(): + symbols, phonemes = make_symbols(**C.characters) + # load speakers if args.speakers_json != '': speakers = json.load(open(args.speakers_json, 'r')) diff --git a/tests/inputs/server_config.json b/tests/inputs/server_config.json index 3988db4c..7f5a60fb 100644 --- a/tests/inputs/server_config.json +++ b/tests/inputs/server_config.json @@ -3,9 +3,11 @@ "tts_config":"dummy_model_config.json", // tts config.json file "tts_speakers": null, // json file listing speaker ids. null if no speaker embedding. "wavernn_lib_path": null, // Rootpath to wavernn project folder to be imported. If this is null, model uses GL for speech synthesis. - "wavernn_path": null, // wavernn model root path "wavernn_file": null, // wavernn checkpoint file name "wavernn_config": null, // wavernn config file + "pwgan_lib_path": null, + "pwgan_file": null, + "pwgan_config": null, "is_wavernn_batched":true, "port": 5002, "use_cuda": false, diff --git a/tests/test_config.json b/tests/test_config.json index 0cd3d751..6d63e6ab 100644 --- a/tests/test_config.json +++ b/tests/test_config.json @@ -19,6 +19,16 @@ "mel_fmax": 7600, // maximum freq level for mel-spec. Tune for dataset!! "do_trim_silence": false }, + + "characters":{ + "pad": "_", + "eos": "~", + "bos": "^", + "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ", + "punctuations":"!'(),-.:;? ", + "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ" + }, + "hidden_size": 128, "embedding_size": 256, "text_cleaner": "english_cleaners", diff --git a/tests/test_demo_server.py b/tests/test_demo_server.py index c343a6a4..a0837686 100644 --- a/tests/test_demo_server.py +++ b/tests/test_demo_server.py @@ -5,13 +5,19 @@ import torch as T from TTS.server.synthesizer import Synthesizer from TTS.tests import get_tests_input_path, get_tests_output_path -from TTS.utils.text.symbols import phonemes, symbols +from TTS.utils.text.symbols import make_symbols, phonemes, symbols from TTS.utils.generic_utils import load_config, save_checkpoint, setup_model class DemoServerTest(unittest.TestCase): + # pylint: disable=R0201 def _create_random_model(self): + # pylint: disable=global-statement + global symbols, phonemes config = load_config(os.path.join(get_tests_output_path(), 'dummy_model_config.json')) + if 'characters' in config.keys(): + symbols, phonemes = make_symbols(**config.characters) + num_chars = len(phonemes) if config.use_phonemes else len(symbols) model = setup_model(num_chars, 0, config) output_path = os.path.join(get_tests_output_path()) diff --git a/tests/test_loader.py b/tests/test_loader.py index 751bc181..d835c5d3 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -37,7 +37,8 @@ class TestTTSDataset(unittest.TestCase): r, c.text_cleaner, ap=self.ap, - meta_data=items, + meta_data=items, + tp=c.characters if 'characters' in c.keys() else None, batch_group_size=bgs, min_seq_len=c.min_seq_len, max_seq_len=float("inf"), @@ -137,9 +138,7 @@ class TestTTSDataset(unittest.TestCase): # NOTE: Below needs to check == 0 but due to an unknown reason # there is a slight difference between two matrices. # TODO: Check this assert cond more in detail. - assert abs((abs(mel.T) - - abs(mel_dl) - ).sum()) < 1e-5, (abs(mel.T) - abs(mel_dl)).sum() + assert abs(mel.T - mel_dl).max() < 1e-5, abs(mel.T - mel_dl).max() # check mel-spec correctness mel_spec = mel_input[0].cpu().numpy() diff --git a/tests/test_server_package.sh b/tests/test_server_package.sh index 01e42843..9fe5e8b1 100755 --- a/tests/test_server_package.sh +++ b/tests/test_server_package.sh @@ -11,7 +11,7 @@ source /tmp/venv/bin/activate pip install --quiet --upgrade pip setuptools wheel rm -f dist/*.whl -python setup.py bdist_wheel --checkpoint tests/outputs/checkpoint_10.pth.tar --model_config tests/outputs/dummy_model_config.json +python setup.py --quiet bdist_wheel --checkpoint tests/outputs/checkpoint_10.pth.tar --model_config tests/outputs/dummy_model_config.json pip install --quiet dist/TTS*.whl python -m TTS.server.server & diff --git a/tests/test_text_processing.py b/tests/test_text_processing.py index 0ecb9962..6c0c7058 100644 --- a/tests/test_text_processing.py +++ b/tests/test_text_processing.py @@ -1,7 +1,14 @@ +import os +# pylint: disable=unused-wildcard-import +# pylint: disable=wildcard-import +# pylint: disable=unused-import import unittest -import torch as T - from TTS.utils.text import * +from TTS.tests import get_tests_path +from TTS.utils.generic_utils import load_config + +TESTS_PATH = get_tests_path() +conf = load_config(os.path.join(TESTS_PATH, 'test_config.json')) def test_phoneme_to_sequence(): text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!" @@ -9,67 +16,80 @@ def test_phoneme_to_sequence(): lang = "en-us" sequence = phoneme_to_sequence(text, text_cleaner, lang) text_hat = sequence_to_phoneme(sequence) + sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) gt = "ɹiːsənt ɹɪsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!" - assert text_hat == gt + assert text_hat == text_hat_with_params == gt # multiple punctuations text = "Be a voice, not an! echo?" sequence = phoneme_to_sequence(text, text_cleaner, lang) text_hat = sequence_to_phoneme(sequence) + sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?" print(text_hat) print(len(sequence)) - assert text_hat == gt + assert text_hat == text_hat_with_params == gt # not ending with punctuation text = "Be a voice, not an! echo" sequence = phoneme_to_sequence(text, text_cleaner, lang) text_hat = sequence_to_phoneme(sequence) + sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ" print(text_hat) print(len(sequence)) - assert text_hat == gt + assert text_hat == text_hat_with_params == gt # original text = "Be a voice, not an echo!" sequence = phoneme_to_sequence(text, text_cleaner, lang) text_hat = sequence_to_phoneme(sequence) + sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!" print(text_hat) print(len(sequence)) - assert text_hat == gt + assert text_hat == text_hat_with_params == gt # extra space after the sentence text = "Be a voice, not an! echo. " sequence = phoneme_to_sequence(text, text_cleaner, lang) text_hat = sequence_to_phoneme(sequence) + sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ." print(text_hat) print(len(sequence)) - assert text_hat == gt + assert text_hat == text_hat_with_params == gt # extra space after the sentence text = "Be a voice, not an! echo. " sequence = phoneme_to_sequence(text, text_cleaner, lang, True) text_hat = sequence_to_phoneme(sequence) + sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~" print(text_hat) print(len(sequence)) - assert text_hat == gt + assert text_hat == text_hat_with_params == gt # padding char text = "_Be a _voice, not an! echo_" sequence = phoneme_to_sequence(text, text_cleaner, lang) text_hat = sequence_to_phoneme(sequence) + sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters) + text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters) gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ" print(text_hat) print(len(sequence)) - assert text_hat == gt - + assert text_hat == text_hat_with_params == gt def test_text2phone(): text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!" gt = "ɹ|iː|s|ə|n|t| |ɹ|ɪ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|n|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ɪ|s|p|ɑː|n|s|ə|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|uː|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!" lang = "en-us" - phonemes = text2phone(text, lang) - assert gt == phonemes, f"\n{phonemes} \n vs \n{gt}" + ph = text2phone(text, lang) + assert gt == ph, f"\n{phonemes} \n vs \n{gt}" \ No newline at end of file diff --git a/train.py b/train.py index e8c240f3..4bb22a34 100644 --- a/train.py +++ b/train.py @@ -20,12 +20,12 @@ from TTS.utils.generic_utils import ( get_git_branch, load_config, remove_experiment_folder, save_best_model, save_checkpoint, adam_weight_decay, set_init_dict, copy_config_file, setup_model, gradual_training_scheduler, KeepAverage, - set_weight_decay) + set_weight_decay, check_config) from TTS.utils.logger import Logger from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \ get_speakers from TTS.utils.synthesis import synthesis -from TTS.utils.text.symbols import phonemes, symbols +from TTS.utils.text.symbols import make_symbols, phonemes, symbols from TTS.utils.visual import plot_alignment, plot_spectrogram from TTS.datasets.preprocess import load_meta_data from TTS.utils.radam import RAdam @@ -49,6 +49,7 @@ def setup_loader(ap, r, is_val=False, verbose=False): c.text_cleaner, meta_data=meta_data_eval if is_val else meta_data_train, ap=ap, + tp=c.characters if 'characters' in c.keys() else None, batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, min_seq_len=c.min_seq_len, @@ -515,9 +516,12 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch): # FIXME: move args definition/parsing inside of main? def main(args): # pylint: disable=redefined-outer-name - global meta_data_train, meta_data_eval + # pylint: disable=global-variable-undefined + global meta_data_train, meta_data_eval, symbols, phonemes # Audio processor ap = AudioProcessor(**c.audio) + if 'characters' in c.keys(): + symbols, phonemes = make_symbols(**c.characters) # DISTRUBUTED if num_gpus > 1: @@ -687,6 +691,7 @@ if __name__ == '__main__': # setup output paths and read configs c = load_config(args.config_path) + check_config(c) _ = os.path.dirname(os.path.realpath(__file__)) OUT_PATH = args.continue_path diff --git a/utils/generic_utils.py b/utils/generic_utils.py index cf1a83a6..cf0a05b4 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -389,3 +389,131 @@ class KeepAverage(): def update_values(self, value_dict): for key, value in value_dict.items(): self.update_value(key, value) + + +def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None): + if restricted: + assert name in c.keys(), f' [!] {name} not defined in config.json' + if name in c.keys(): + if max_val: + assert c[name] <= max_val, f' [!] {name} is larger than max value {max_val}' + if min_val: + assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}' + if enum_list: + assert c[name].lower() in enum_list, f' [!] {name} is not a valid value' + if val_type: + assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}' + + +def check_config(c): + _check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str) + _check_argument('run_name', c, restricted=True, val_type=str) + _check_argument('run_description', c, val_type=str) + + # AUDIO + _check_argument('audio', c, restricted=True, val_type=dict) + + # audio processing parameters + _check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056) + _check_argument('num_freq', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058) + _check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000) + _check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000) + _check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000) + _check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1) + _check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10) + _check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000) + _check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5) + _check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000) + + # vocabulary parameters + _check_argument('characters', c, restricted=False, val_type=dict) + _check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + _check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + _check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + _check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + _check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + _check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + + # normalization parameters + _check_argument('signal_norm', c['audio'], restricted=True, val_type=bool) + _check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool) + _check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000) + _check_argument('clip_norm', c['audio'], restricted=True, val_type=bool) + _check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000) + _check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0) + _check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool) + _check_argument('trim_db', c['audio'], restricted=True, val_type=int) + + # training parameters + _check_argument('batch_size', c, restricted=True, val_type=int, min_val=1) + _check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1) + _check_argument('r', c, restricted=True, val_type=int, min_val=1) + _check_argument('gradual_training', c, restricted=False, val_type=list) + _check_argument('loss_masking', c, restricted=True, val_type=bool) + # _check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100) + + # validation parameters + _check_argument('run_eval', c, restricted=True, val_type=bool) + _check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0) + _check_argument('test_sentences_file', c, restricted=False, val_type=str) + + # optimizer + _check_argument('noam_schedule', c, restricted=False, val_type=bool) + _check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0) + _check_argument('epochs', c, restricted=True, val_type=int, min_val=1) + _check_argument('lr', c, restricted=True, val_type=float, min_val=0) + _check_argument('wd', c, restricted=True, val_type=float, min_val=0) + _check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0) + _check_argument('seq_len_norm', c, restricted=True, val_type=bool) + + # tacotron prenet + _check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1) + _check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn']) + _check_argument('prenet_dropout', c, restricted=True, val_type=bool) + + # attention + _check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original']) + _check_argument('attention_heads', c, restricted=True, val_type=int) + _check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax']) + _check_argument('windowing', c, restricted=True, val_type=bool) + _check_argument('use_forward_attn', c, restricted=True, val_type=bool) + _check_argument('forward_attn_mask', c, restricted=True, val_type=bool) + _check_argument('transition_agent', c, restricted=True, val_type=bool) + _check_argument('transition_agent', c, restricted=True, val_type=bool) + _check_argument('location_attn', c, restricted=True, val_type=bool) + _check_argument('bidirectional_decoder', c, restricted=True, val_type=bool) + + # stopnet + _check_argument('stopnet', c, restricted=True, val_type=bool) + _check_argument('separate_stopnet', c, restricted=True, val_type=bool) + + # tensorboard + _check_argument('print_step', c, restricted=True, val_type=int, min_val=1) + _check_argument('save_step', c, restricted=True, val_type=int, min_val=1) + _check_argument('checkpoint', c, restricted=True, val_type=bool) + _check_argument('tb_model_param_stats', c, restricted=True, val_type=bool) + + # dataloading + _check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=['english_cleaners', 'phoneme_cleaners', 'transliteration_cleaners', 'basic_cleaners']) + _check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool) + _check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0) + _check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0) + _check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0) + _check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0) + _check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10) + + # paths + _check_argument('output_path', c, restricted=True, val_type=str) + + # multi-speaker gst + _check_argument('use_speaker_embedding', c, restricted=True, val_type=bool) + _check_argument('style_wav_for_test', c, restricted=True, val_type=str) + _check_argument('use_gst', c, restricted=True, val_type=bool) + + # datasets - checking only the first entry + _check_argument('datasets', c, restricted=True, val_type=list) + for dataset_entry in c['datasets']: + _check_argument('name', dataset_entry, restricted=True, val_type=str) + _check_argument('path', dataset_entry, restricted=True, val_type=str) + _check_argument('meta_file_train', dataset_entry, restricted=True, val_type=str) + _check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str) \ No newline at end of file diff --git a/utils/synthesis.py b/utils/synthesis.py index f066228a..42f0408c 100644 --- a/utils/synthesis.py +++ b/utils/synthesis.py @@ -9,10 +9,11 @@ def text_to_seqvec(text, CONFIG, use_cuda): if CONFIG.use_phonemes: seq = np.asarray( phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language, - CONFIG.enable_eos_bos_chars), + CONFIG.enable_eos_bos_chars, + tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), dtype=np.int32) else: - seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32) + seq = np.asarray(text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), dtype=np.int32) # torch tensor chars_var = torch.from_numpy(seq).unsqueeze(0) if use_cuda: @@ -78,6 +79,7 @@ def synthesis(model, style_wav=None, truncated=False, enable_eos_bos_chars=False, #pylint: disable=unused-argument + use_griffin_lim=False, do_trim_silence=False): """Synthesize voice for the given text. @@ -111,8 +113,10 @@ def synthesis(model, postnet_output, decoder_output, alignment = parse_outputs( postnet_output, decoder_output, alignments) # plot results - wav = inv_spectrogram(postnet_output, ap, CONFIG) - # trim silence - if do_trim_silence: - wav = trim_silence(wav, ap) + wav = None + if use_griffin_lim: + wav = inv_spectrogram(postnet_output, ap, CONFIG) + # trim silence + if do_trim_silence: + wav = trim_silence(wav, ap) return wav, alignment, decoder_output, postnet_output, stop_tokens diff --git a/utils/text/__init__.py b/utils/text/__init__.py index e6842dfa..2cbdafc9 100644 --- a/utils/text/__init__.py +++ b/utils/text/__init__.py @@ -4,15 +4,15 @@ import re import phonemizer from phonemizer.phonemize import phonemize from TTS.utils.text import cleaners -from TTS.utils.text.symbols import symbols, phonemes, _phoneme_punctuations, _bos, \ +from TTS.utils.text.symbols import make_symbols, symbols, phonemes, _phoneme_punctuations, _bos, \ _eos # Mappings from symbol to numeric ID and vice versa: -_SYMBOL_TO_ID = {s: i for i, s in enumerate(symbols)} -_ID_TO_SYMBOL = {i: s for i, s in enumerate(symbols)} +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} -_PHONEMES_TO_ID = {s: i for i, s in enumerate(phonemes)} -_ID_TO_PHONEMES = {i: s for i, s in enumerate(phonemes)} +_phonemes_to_id = {s: i for i, s in enumerate(phonemes)} +_id_to_phonemes = {i: s for i, s in enumerate(phonemes)} # Regular expression matching text enclosed in curly braces: _CURLY_RE = re.compile(r'(.*?)\{(.+?)\}(.*)') @@ -38,14 +38,11 @@ def text2phone(text, language): if text[-1] == punctuations[-1]: for punct in punctuations[:-1]: ph = ph.replace('| |\n', '|'+punct+'| |', 1) - try: ph = ph + punctuations[-1] - except: - print(text) else: for punct in punctuations: ph = ph.replace('| |\n', '|'+punct+'| |', 1) - elif float(phonemizer.__version__) == 2.1: + elif float(phonemizer.__version__) > 2.1: ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language, preserve_punctuation=True) # this is a simple fix for phonemizer. # https://github.com/bootphon/phonemizer/issues/32 @@ -59,11 +56,25 @@ def text2phone(text, language): return ph -def pad_with_eos_bos(phoneme_sequence): - return [_PHONEMES_TO_ID[_bos]] + list(phoneme_sequence) + [_PHONEMES_TO_ID[_eos]] +def pad_with_eos_bos(phoneme_sequence, tp=None): + # pylint: disable=global-statement + global _phonemes_to_id, _bos, _eos + if tp: + _bos = tp['bos'] + _eos = tp['eos'] + _, _phonemes = make_symbols(**tp) + _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)} + + return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]] -def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False): +def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None): + # pylint: disable=global-statement + global _phonemes_to_id + if tp: + _, _phonemes = make_symbols(**tp) + _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)} + sequence = [] text = text.replace(":", "") clean_text = _clean_text(text, cleaner_names) @@ -75,21 +86,27 @@ def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False): sequence += _phoneme_to_sequence(phoneme) # Append EOS char if enable_eos_bos: - sequence = pad_with_eos_bos(sequence) + sequence = pad_with_eos_bos(sequence, tp=tp) return sequence -def sequence_to_phoneme(sequence): +def sequence_to_phoneme(sequence, tp=None): + # pylint: disable=global-statement '''Converts a sequence of IDs back to a string''' + global _id_to_phonemes result = '' + if tp: + _, _phonemes = make_symbols(**tp) + _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)} + for symbol_id in sequence: - if symbol_id in _ID_TO_PHONEMES: - s = _ID_TO_PHONEMES[symbol_id] + if symbol_id in _id_to_phonemes: + s = _id_to_phonemes[symbol_id] result += s return result.replace('}{', ' ') -def text_to_sequence(text, cleaner_names): +def text_to_sequence(text, cleaner_names, tp=None): '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. The text can optionally have ARPAbet sequences enclosed in curly braces embedded @@ -102,6 +119,12 @@ def text_to_sequence(text, cleaner_names): Returns: List of integers corresponding to the symbols in the text ''' + # pylint: disable=global-statement + global _symbol_to_id + if tp: + _symbols, _ = make_symbols(**tp) + _symbol_to_id = {s: i for i, s in enumerate(_symbols)} + sequence = [] # Check for curly braces and treat their contents as ARPAbet: while text: @@ -116,12 +139,18 @@ def text_to_sequence(text, cleaner_names): return sequence -def sequence_to_text(sequence): +def sequence_to_text(sequence, tp=None): '''Converts a sequence of IDs back to a string''' + # pylint: disable=global-statement + global _id_to_symbol + if tp: + _symbols, _ = make_symbols(**tp) + _id_to_symbol = {i: s for i, s in enumerate(_symbols)} + result = '' for symbol_id in sequence: - if symbol_id in _ID_TO_SYMBOL: - s = _ID_TO_SYMBOL[symbol_id] + if symbol_id in _id_to_symbol: + s = _id_to_symbol[symbol_id] # Enclose ARPAbet back in curly braces: if len(s) > 1 and s[0] == '@': s = '{%s}' % s[1:] @@ -139,11 +168,11 @@ def _clean_text(text, cleaner_names): def _symbols_to_sequence(syms): - return [_SYMBOL_TO_ID[s] for s in syms if _should_keep_symbol(s)] + return [_symbol_to_id[s] for s in syms if _should_keep_symbol(s)] def _phoneme_to_sequence(phons): - return [_PHONEMES_TO_ID[s] for s in list(phons) if _should_keep_phoneme(s)] + return [_phonemes_to_id[s] for s in list(phons) if _should_keep_phoneme(s)] def _arpabet_to_sequence(text): @@ -151,8 +180,8 @@ def _arpabet_to_sequence(text): def _should_keep_symbol(s): - return s in _SYMBOL_TO_ID and s not in ['~', '^', '_'] + return s in _symbol_to_id and s not in ['~', '^', '_'] def _should_keep_phoneme(p): - return p in _PHONEMES_TO_ID and p not in ['~', '^', '_'] + return p in _phonemes_to_id and p not in ['~', '^', '_'] diff --git a/utils/text/symbols.py b/utils/text/symbols.py index ee6fd2cf..544277c5 100644 --- a/utils/text/symbols.py +++ b/utils/text/symbols.py @@ -5,6 +5,18 @@ Defines the set of symbols used in text input to the model. The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' +def make_symbols(characters, phonemes, punctuations='!\'(),-.:;? ', pad='_', eos='~', bos='^'):# pylint: disable=redefined-outer-name + ''' Function to create symbols and phonemes ''' + _phonemes_sorted = sorted(list(phonemes)) + + # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): + _arpabet = ['@' + s for s in _phonemes_sorted] + + # Export all symbols: + _symbols = [pad, eos, bos] + list(characters) + _arpabet + _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations) + + return _symbols, _phonemes _pad = '_' _eos = '~' @@ -20,14 +32,9 @@ _pulmonic_consonants = 'pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðsz _suprasegmentals = 'ˈˌːˑ' _other_symbols = 'ʍwɥʜʢʡɕʑɺɧ' _diacrilics = 'ɚ˞ɫ' -_phonemes = sorted(list(_vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics)) +_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics -# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): -_arpabet = ['@' + s for s in _phonemes] - -# Export all symbols: -symbols = [_pad, _eos, _bos] + list(_characters) + _arpabet -phonemes = [_pad, _eos, _bos] + list(_phonemes) + list(_punctuations) +symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _eos, _bos) # Generate ALIEN language # from random import shuffle diff --git a/utils/visual.py b/utils/visual.py index ab513666..1cb9ac5d 100644 --- a/utils/visual.py +++ b/utils/visual.py @@ -54,9 +54,10 @@ def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CON plt.xlabel("Decoder timestamp", fontsize=label_fontsize) plt.ylabel("Encoder timestamp", fontsize=label_fontsize) if CONFIG.use_phonemes: - seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars) - text = sequence_to_phoneme(seq) + seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None) + text = sequence_to_phoneme(seq, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None) print(text) + plt.yticks(range(len(text)), list(text)) plt.colorbar()