Merge branch 'Edresson-dev' into dev

Author: erogol
Date: 2020-03-09 10:22:09 +01:00
Commit: c6440c257e
27 changed files with 481 additions and 181 deletions

.github/stale.yml (new file)

@@ -0,0 +1,19 @@
+# Number of days of inactivity before an issue becomes stale
+daysUntilStale: 60
+# Number of days of inactivity before a stale issue is closed
+daysUntilClose: 7
+# Issues with these labels will never be considered stale
+exemptLabels:
+  - pinned
+  - security
+# Label to use when marking an issue as stale
+staleLabel: wontfix
+# Comment to post when marking an issue as stale. Set to `false` to disable
+markComment: >
+  This issue has been automatically marked as stale because it has not had
+  recent activity. It will be closed if no further activity occurs. Thank you
+  for your contributions. You might also check our Discourse page for further help.
+  https://discourse.mozilla.org/c/tts
+# Comment to post when closing a stale issue. Set to `false` to disable
+closeComment: false


@@ -115,10 +115,7 @@ In case of any error or intercepted execution, if there is no checkpoint yet und
 You can also enjoy Tensorboard, if you point the Tensorboard argument ```--logdir``` to the experiment folder.
-## Testing
-The best way to test your network is to use the Notebooks under the ```notebooks``` folder.
-There is also a good [CoLab](https://colab.research.google.com/github/tugstugi/dl-colab-notebooks/blob/master/notebooks/Mozilla_TTS_WaveRNN.ipynb) sample using pre-trained models (by @tugstugi).
+## [Testing and Examples](https://github.com/mozilla/TTS/wiki/Examples-using-TTS)
 ## Contribution guidelines
 This repository is governed by Mozilla's code of conduct and etiquette guidelines. For more details, please read the [Mozilla Community Participation Guidelines.](https://www.mozilla.org/about/governance/policies/participation/)
@@ -139,28 +136,7 @@ If you like to use TTS to try a new idea and like to share your experiments with
 - Share your results as you proceed. (Tensorboard log files, audio results, visuals etc.)
 - Use the LJSpeech dataset (for English) if you like to compare results with the released models. (It is the most open, scalable dataset for quick experimentation.)
-## Contact/Getting Help
-- [Wiki](https://github.com/mozilla/TTS/wiki)
-- [Discourse Forums](https://discourse.mozilla.org/c/tts) - If your question is not addressed in the Wiki, the Discourse Forums is the next place to look. They contain conversations on General Topics, Using TTS, and TTS Development.
-- [Issues](https://github.com/mozilla/TTS/issues) - Finally, if all else fails, you can open an issue in our repo.
+## [Contact/Getting Help](https://github.com/mozilla/TTS/wiki/Contact-and-Getting-Help)
-<!--## What is new with TTS
-If you train TTS with the LJSpeech dataset, you start to hear reasonable results after 12.5K iterations with batch size 32. This is the fastest training with character-based methods to our knowledge. Our implementation is also quite robust against long sentences.
-- Location sensitive attention ([ref](https://arxiv.org/pdf/1506.07503.pdf)). Attention is a vital part of text2speech models. Therefore, it is important to use an attention mechanism that suits the diagonal nature of the problem, where the output strictly aligns with the text monotonically. Location sensitive attention performs better by looking into the previous alignment vectors and learns diagonal attention more easily. Yet, I believe there is a good space for research at this front to find a better solution.
-- Attention smoothing with sigmoid ([ref](https://arxiv.org/pdf/1506.07503.pdf)). Attention weights are computed by normalized sigmoid values instead of softmax for sharper values. That enables the model to pick multiple highly scored inputs for alignments while reducing the noise.
-- Weight decay ([ref](http://www.fast.ai/2018/07/02/adam-weight-decay/)). After a certain point of the training, you might observe the model over-fitting: the model is able to pronounce words probably better, but the quality of the speech gets lower and sometimes attention alignment gets disoriented.
-- Stop token prediction with an additional module. The original Tacotron model does not propose a stop token to stop the decoding process. Therefore, you need to use heuristic measures to stop the decoder. Here, we prefer to use additional layers at the end to decide when to stop.
-- Applying sigmoid to the model outputs. Since the output values are expected to be in the range [0, 1], we apply sigmoid to make things easier to approximate the expected output distribution.
-- Phoneme based training is enabled for easier learning and robust pronunciation. It also makes it easier to adapt TTS to most languages without worrying about language-specific characters.
-- Configurable attention windowing at inference-time for robust alignment. It enforces the network to only consider a certain window of encoder steps per iteration.
-- Detailed Tensorboard stats for activation, weight and gradient values per layer. It is useful to detect defects and compare networks.
-- Constant history window. Instead of using only the last frame of predictions, define a constant history queue. It enables training with a gradually decreasing prediction frame (r=5 -> r=1) by only changing the last layer. For instance, you can train the model with r=5 and then fine-tune it with r=1 without any performance loss. It also solves the well-known PreNet problem [#50](https://github.com/mozilla/TTS/issues/50).
-- Initialization of hidden decoder states with Embedding layers instead of zero initialization.
-One common question is why we don't use the Tacotron2 architecture. According to our ablation experiments, nothing, except Location Sensitive Attention, improves the performance, given the increase in the model size.
-Please feel free to offer new changes and pull things off. We are happy to discuss and make things better.
--->
 ## Major TODOs
 - [x] Implement the model.


@@ -11,6 +11,8 @@
     "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different from the original data, it is resampled.
     "win_length": 1024, // stft window length in audio samples.
     "hop_length": 256, // stft window hop-length in audio samples.
+    "frame_length_ms": null, // stft window length in ms. If null, 'win_length' is used.
+    "frame_shift_ms": null, // stft window hop-length in ms. If null, 'hop_length' is used.
     "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
     "min_level_db": -100, // normalization range
     "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.
@@ -19,7 +21,7 @@
     // Normalization parameters
     "signal_norm": true, // normalize the spec values in range [0, 1]
     "symmetric_norm": true, // move normalization to range [-1, 1]
-    "max_norm": 4, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
+    "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
     "clip_norm": true, // clip normalized values into the range.
     "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
     "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
@@ -27,6 +29,18 @@
     "trim_db": 60 // threshold for trimming silence. Set this according to your dataset.
 },
+    // VOCABULARY PARAMETERS
+    // if a custom character set is not defined,
+    // the default set in symbols.py is used
+    "characters":{
+        "pad": "_",
+        "eos": "~",
+        "bos": "^",
+        "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ",
+        "punctuations":"!'(),-.:;? ",
+        "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ"
+    },
     // DISTRIBUTED TRAINING
     "distributed":{
         "backend": "nccl",
@@ -36,11 +50,12 @@
     "reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.
     // TRAINING
-    "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
+    "batch_size": 2, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
     "eval_batch_size":16,
     "r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
     "gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], // set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceed.
     "loss_masking": true, // enable / disable loss masking against the sequence padding.
+    "grad_accum": 2, // if N > 1, enable gradient accumulation for N iterations. It is useful for low memory GPUs.
     // VALIDATION
     "run_eval": true,
@@ -49,7 +64,7 @@
     // OPTIMIZER
     "noam_schedule": false, // use noam warmup and lr schedule.
-    "grad_clip": 1, // upper limit for gradients for clipping.
+    "grad_clip": 1.0, // upper limit for gradients for clipping.
     "epochs": 1000, // total number of epochs to train.
     "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
     "wd": 0.000001, // Weight decay weight.


@@ -15,6 +15,7 @@ class MyDataset(Dataset):
                  text_cleaner,
                  ap,
                  meta_data,
+                 tp=None,
                  batch_group_size=0,
                  min_seq_len=0,
                  max_seq_len=float("inf"),
@@ -49,6 +50,7 @@ class MyDataset(Dataset):
         self.min_seq_len = min_seq_len
         self.max_seq_len = max_seq_len
         self.ap = ap
+        self.tp = tp
         self.use_phonemes = use_phonemes
         self.phoneme_cache_path = phoneme_cache_path
         self.phoneme_language = phoneme_language
@@ -75,13 +77,13 @@ class MyDataset(Dataset):
     def _generate_and_cache_phoneme_sequence(self, text, cache_path):
         """Generate a phoneme sequence from text.
        Since the usage is for subsequent caching, we never add bos and
        eos chars here. Instead we add those dynamically later, based on the
        config option."""
         phonemes = phoneme_to_sequence(text, [self.cleaners],
                                        language=self.phoneme_language,
-                                       enable_eos_bos=False)
+                                       enable_eos_bos=False,
+                                       tp=self.tp)
         phonemes = np.asarray(phonemes, dtype=np.int32)
         np.save(cache_path, phonemes)
         return phonemes
@@ -101,7 +103,7 @@ class MyDataset(Dataset):
             phonemes = self._generate_and_cache_phoneme_sequence(text,
                                                                  cache_path)
             if self.enable_eos_bos:
-                phonemes = pad_with_eos_bos(phonemes)
+                phonemes = pad_with_eos_bos(phonemes, tp=self.tp)
             phonemes = np.asarray(phonemes, dtype=np.int32)
         return phonemes
@@ -113,7 +115,7 @@ class MyDataset(Dataset):
             text = self._load_or_generate_phoneme_sequence(wav_file, text)
         else:
             text = np.asarray(
-                text_to_sequence(text, [self.cleaners]), dtype=np.int32)
+                text_to_sequence(text, [self.cleaners], tp=self.tp), dtype=np.int32)
         assert text.size > 0, self.items[idx][1]
         assert wav.size > 0, self.items[idx][1]
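
The new `tp` argument threads the optional `characters` config block through to the text-processing functions. A minimal wiring sketch, mirroring how the notebooks and `train.py` elsewhere in this commit construct the dataset (`config`, `ap`, and `meta_data` are assumed to be the loaded config, AudioProcessor, and preprocessor output):

```python
tp = config.characters if 'characters' in config.keys() else None
dataset = MyDataset(config.r, config.text_cleaner, ap, meta_data,
                    tp=tp,  # custom vocabulary, forwarded to text_to_sequence / phoneme_to_sequence
                    use_phonemes=config.use_phonemes,
                    phoneme_cache_path=config.phoneme_cache_path,
                    enable_eos_bos=config.enable_eos_bos_chars)
```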


@@ -60,22 +60,6 @@ def tweb(root_path, meta_file):
     # return {'text': texts, 'wavs': wavs}
-def mozilla_old(root_path, meta_file):
-    """Normalizes Mozilla meta data files to TTS format"""
-    txt_file = os.path.join(root_path, meta_file)
-    items = []
-    speaker_name = "mozilla_old"
-    with open(txt_file, 'r') as ttf:
-        for line in ttf:
-            cols = line.split('|')
-            batch_no = int(cols[1].strip().split("_")[0])
-            wav_folder = "batch{}".format(batch_no)
-            wav_file = os.path.join(root_path, wav_folder, "wavs_no_processing", cols[1].strip())
-            text = cols[0].strip()
-            items.append([text, wav_file, speaker_name])
-    return items
 def mozilla(root_path, meta_file):
     """Normalizes Mozilla meta data files to TTS format"""
     txt_file = os.path.join(root_path, meta_file)
@@ -91,6 +75,22 @@ def mozilla(root_path, meta_file):
     return items
+def mozilla_de(root_path, meta_file):
+    """Normalizes Mozilla meta data files to TTS format"""
+    txt_file = os.path.join(root_path, meta_file)
+    items = []
+    speaker_name = "mozilla"
+    with open(txt_file, 'r', encoding="ISO 8859-1") as ttf:
+        for line in ttf:
+            cols = line.strip().split('|')
+            wav_file = cols[0].strip()
+            text = cols[1].strip()
+            folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
+            wav_file = os.path.join(root_path, folder_name, wav_file)
+            items.append([text, wav_file, speaker_name])
+    return items
 def mailabs(root_path, meta_files=None):
     """Normalizes M-AI-Labs meta data files to TTS format"""
     speaker_regex = re.compile("by_book/(male|female)/(?P<speaker_name>[^/]+)/")
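
The new `mozilla_de` formatter expects Latin-1 encoded metadata lines of the form `<wav_file>|<text>`, where the file name's leading batch number selects a `BATCH_<n>_FINAL` folder. A hypothetical line and the path derived from it (the file names below are made up for illustration):

```python
# metadata line (hypothetical):  6_0001.wav|Guten Morgen.
# resolved wav path:             <root_path>/BATCH_6_FINAL/6_0001.wav
items = mozilla_de("/data/mozilla_de", "metadata.csv")
text, wav_file, speaker_name = items[0]
```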


@@ -64,7 +64,6 @@ class Encoder(nn.Module):
     def forward(self, x, input_lengths):
         x = self.convolutions(x)
         x = x.transpose(1, 2)
-        input_lengths = input_lengths.cpu().numpy()
         x = nn.utils.rnn.pack_padded_sequence(x,
                                               input_lengths,
                                               batch_first=True)
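
The deleted line is no longer needed because `pack_padded_sequence` accepts the lengths tensor directly. A self-contained sketch of the pattern (shapes are illustrative):

```python
import torch
import torch.nn as nn

x = torch.randn(4, 10, 8)                    # (batch, max_time, features)
input_lengths = torch.tensor([10, 9, 7, 4])  # must be sorted in decreasing order here
packed = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True)
outputs, _ = nn.LSTM(8, 16, batch_first=True)(packed)
outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
```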


@@ -132,7 +132,7 @@
 "outputs": [],
 "source": [
 "# LOAD TTS MODEL\n",
-"from TTS.utils.text.symbols import symbols, phonemes\n",
+"from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n",
 "\n",
 "# multi speaker \n",
 "if CONFIG.use_speaker_embedding:\n",
@@ -142,6 +142,10 @@
 "    speakers = []\n",
 "    speaker_id = None\n",
 "\n",
+"# if the vocabulary was passed, replace the default\n",
+"if 'characters' in CONFIG.keys():\n",
+"    symbols, phonemes = make_symbols(**CONFIG.characters)\n",
+"\n",
 "# load the model\n",
 "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n",
 "model = setup_model(num_chars, len(speakers), CONFIG)\n",


@@ -65,6 +65,7 @@
 "from TTS.utils.text import text_to_sequence\n",
 "from TTS.utils.synthesis import synthesis\n",
 "from TTS.utils.visual import visualize\n",
+"from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n",
 "\n",
 "import IPython\n",
 "from IPython.display import Audio\n",
@@ -81,13 +82,15 @@
 "source": [
 "def tts(model, text, CONFIG, use_cuda, ap, use_gl, figures=True):\n",
 "    t_1 = time.time()\n",
-"    waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, False, CONFIG.enable_eos_bos_chars)\n",
+"    waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, style_wav=None,\n",
+"                                                                             truncated=False, enable_eos_bos_chars=CONFIG.enable_eos_bos_chars,\n",
+"                                                                             use_griffin_lim=use_gl)\n",
 "    if CONFIG.model == \"Tacotron\" and not use_gl:\n",
 "        # correct the normalization differences b/w TTS and the Vocoder.\n",
 "        mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n",
-"        mel_postnet_spec = ap._denormalize(mel_postnet_spec)\n",
-"        mel_postnet_spec = ap_vocoder._normalize(mel_postnet_spec)\n",
 "    if not use_gl:\n",
+"        mel_postnet_spec = ap._denormalize(mel_postnet_spec)\n",
+"        mel_postnet_spec = ap_vocoder._normalize(mel_postnet_spec)\n",
 "        waveform = wavernn.generate(torch.FloatTensor(mel_postnet_spec.T).unsqueeze(0).cuda(), batched=batched_wavernn, target=8000, overlap=400)\n",
 "\n",
 "    print(\" > Run-time: {}\".format(time.time() - t_1))\n",
@@ -108,7 +111,7 @@
 "outputs": [],
 "source": [
 "# Set constants\n",
-"ROOT_PATH = '/media/erogol/data_ssd/Models/libri_tts/5099/'\n",
+"ROOT_PATH = '/home/erogol/Models/LJSpeech/ljspeech-bn-December-23-2019_08+34AM-ffea133/'\n",
 "MODEL_PATH = ROOT_PATH + '/best_model.pth.tar'\n",
 "CONFIG_PATH = ROOT_PATH + '/config.json'\n",
 "OUT_FOLDER = '/home/erogol/Dropbox/AudioSamples/benchmark_samples/'\n",
@@ -116,7 +119,7 @@
 "VOCODER_MODEL_PATH = \"/media/erogol/data_ssd/Models/wavernn/ljspeech/mold_ljspeech_best_model/checkpoint_433000.pth.tar\"\n",
 "VOCODER_CONFIG_PATH = \"/media/erogol/data_ssd/Models/wavernn/ljspeech/mold_ljspeech_best_model/config.json\"\n",
 "VOCODER_CONFIG = load_config(VOCODER_CONFIG_PATH)\n",
-"use_cuda = False\n",
+"use_cuda = True\n",
 "\n",
 "# Set some config fields manually for testing\n",
 "# CONFIG.windowing = False\n",
@@ -127,7 +130,7 @@
 "# CONFIG.stopnet = True\n",
 "\n",
 "# Set the vocoder\n",
-"use_gl = False # use GL if True\n",
+"use_gl = True # use GL if True\n",
 "batched_wavernn = True # use batched wavernn inference if True"
 ]
 },
@@ -138,8 +141,6 @@
 "outputs": [],
 "source": [
 "# LOAD TTS MODEL\n",
-"from utils.text.symbols import symbols, phonemes\n",
-"\n",
 "# multi speaker \n",
 "if CONFIG.use_speaker_embedding:\n",
 "    speakers = json.load(open(f\"{ROOT_PATH}/speakers.json\", 'r'))\n",
@@ -148,6 +149,10 @@
 "    speakers = []\n",
 "    speaker_id = None\n",
 "\n",
+"# if the vocabulary was passed, replace the default\n",
+"if 'characters' in CONFIG.keys():\n",
+"    symbols, phonemes = make_symbols(**CONFIG.characters)\n",
+"\n",
 "# load the model\n",
 "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n",
 "model = setup_model(num_chars, len(speakers), CONFIG)\n",
@@ -181,7 +186,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# LOAD WAVERNN\n",
+"# LOAD WAVERNN - Make sure you downloaded the model and installed the module\n",
 "if use_gl == False:\n",
 "    from WaveRNN.models.wavernn import Model\n",
 "    from WaveRNN.utils.audio import AudioProcessor as AudioProcessorVocoder\n",
@@ -533,7 +538,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.7.3"
+"version": "3.7.4"
 }
 },
 "nbformat": 4,


@@ -37,7 +37,7 @@
 "from TTS.utils.audio import AudioProcessor\n",
 "from TTS.utils.visual import plot_spectrogram\n",
 "from TTS.utils.generic_utils import load_config, setup_model, sequence_mask\n",
-"from TTS.utils.text.symbols import symbols, phonemes\n",
+"from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n",
 "\n",
 "%matplotlib inline\n",
 "\n",
@@ -94,6 +94,10 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"# if the vocabulary was passed, replace the default\n",
+"if 'characters' in C.keys():\n",
+"    symbols, phonemes = make_symbols(**C.characters)\n",
+"\n",
 "# load the model\n",
 "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
 "# TODO: multiple speaker\n",
@@ -116,7 +120,7 @@
 "preprocessor = importlib.import_module('TTS.datasets.preprocess')\n",
 "preprocessor = getattr(preprocessor, DATASET.lower())\n",
 "meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
-"dataset = MyDataset(checkpoint['r'], C.text_cleaner, ap, meta_data, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n",
+"dataset = MyDataset(checkpoint['r'], C.text_cleaner, ap, meta_data, tp=C.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n",
 "loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)"
 ]
 },


@@ -100,7 +100,7 @@
 "outputs": [],
 "source": [
 "# LOAD TTS MODEL\n",
-"from TTS.utils.text.symbols import symbols, phonemes\n",
+"from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n",
 "\n",
 "# multi speaker \n",
 "if CONFIG.use_speaker_embedding:\n",
@@ -110,6 +110,10 @@
 "    speakers = []\n",
 "    speaker_id = None\n",
 "\n",
+"# if the vocabulary was passed, replace the default\n",
+"if 'characters' in CONFIG.keys():\n",
+"    symbols, phonemes = make_symbols(**CONFIG.characters)\n",
+"\n",
 "# load the model\n",
 "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n",
 "model = setup_model(num_chars, len(speakers), CONFIG)\n",


@@ -6,6 +6,10 @@ Instructions below are based on an Ubuntu 18.04 machine, but it should be simple
 #### Development server:
+##### Using server.py
+If you already have the environment set up for TTS, you can directly run ```server.py```.
+
+##### Using .whl
 1. apt-get install -y espeak libsndfile1 python3-venv
 2. python3 -m venv /tmp/venv
 3. source /tmp/venv/bin/activate


@@ -14,10 +14,13 @@ def create_argparser():
     parser.add_argument('--tts_checkpoint', type=str, help='path to TTS checkpoint file')
     parser.add_argument('--tts_config', type=str, help='path to TTS config.json file')
     parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model')
-    parser.add_argument('--wavernn_lib_path', type=str, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
-    parser.add_argument('--wavernn_file', type=str, help='path to WaveRNN checkpoint file.')
-    parser.add_argument('--wavernn_config', type=str, help='path to WaveRNN config file.')
+    parser.add_argument('--wavernn_lib_path', type=str, default=None, help='path to WaveRNN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
+    parser.add_argument('--wavernn_file', type=str, default=None, help='path to WaveRNN checkpoint file.')
+    parser.add_argument('--wavernn_config', type=str, default=None, help='path to WaveRNN config file.')
     parser.add_argument('--is_wavernn_batched', type=convert_boolean, default=False, help='true to use batched WaveRNN.')
+    parser.add_argument('--pwgan_lib_path', type=str, default=None, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.')
+    parser.add_argument('--pwgan_file', type=str, default=None, help='path to ParallelWaveGAN checkpoint file.')
+    parser.add_argument('--pwgan_config', type=str, default=None, help='path to ParallelWaveGAN config file.')
     parser.add_argument('--port', type=int, default=5002, help='port to listen on.')
     parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.')
     parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.')
@@ -26,28 +29,35 @@ def create_argparser():
 synthesizer = None
-embedded_model_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model')
-checkpoint_file = os.path.join(embedded_model_folder, 'checkpoint.pth.tar')
-config_file = os.path.join(embedded_model_folder, 'config.json')
+embedded_models_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model')
-# Default options with embedded model files
-if os.path.isfile(checkpoint_file):
-    default_tts_checkpoint = checkpoint_file
-else:
-    default_tts_checkpoint = None
-if os.path.isfile(config_file):
-    default_tts_config = config_file
-else:
-    default_tts_config = None
+embedded_tts_folder = os.path.join(embedded_models_folder, 'tts')
+tts_checkpoint_file = os.path.join(embedded_tts_folder, 'checkpoint.pth.tar')
+tts_config_file = os.path.join(embedded_tts_folder, 'config.json')
+embedded_wavernn_folder = os.path.join(embedded_models_folder, 'wavernn')
+wavernn_checkpoint_file = os.path.join(embedded_wavernn_folder, 'checkpoint.pth.tar')
+wavernn_config_file = os.path.join(embedded_wavernn_folder, 'config.json')
+embedded_pwgan_folder = os.path.join(embedded_models_folder, 'pwgan')
+pwgan_checkpoint_file = os.path.join(embedded_pwgan_folder, 'checkpoint.pkl')
+pwgan_config_file = os.path.join(embedded_pwgan_folder, 'config.yml')
 args = create_argparser().parse_args()
-# If these were not specified in the CLI args, use default values
-if not args.tts_checkpoint:
-    args.tts_checkpoint = default_tts_checkpoint
-if not args.tts_config:
-    args.tts_config = default_tts_config
+# If these were not specified in the CLI args, use default values with embedded model files
+if not args.tts_checkpoint and os.path.isfile(tts_checkpoint_file):
+    args.tts_checkpoint = tts_checkpoint_file
+if not args.tts_config and os.path.isfile(tts_config_file):
+    args.tts_config = tts_config_file
+if not args.wavernn_file and os.path.isfile(wavernn_checkpoint_file):
+    args.wavernn_file = wavernn_checkpoint_file
+if not args.wavernn_config and os.path.isfile(wavernn_config_file):
+    args.wavernn_config = wavernn_config_file
+if not args.pwgan_file and os.path.isfile(pwgan_checkpoint_file):
+    args.pwgan_file = pwgan_checkpoint_file
+if not args.pwgan_config and os.path.isfile(pwgan_config_file):
+    args.pwgan_config = pwgan_config_file
 synthesizer = Synthesizer(args)
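
With the new flags and embedded-model fallbacks, the server can be pointed at a ParallelWaveGAN vocoder instead of WaveRNN. A hypothetical invocation (all paths below are placeholders):

```bash
python -m TTS.server.server \
    --tts_checkpoint /models/tts/checkpoint.pth.tar \
    --tts_config /models/tts/config.json \
    --pwgan_lib_path /opt/ParallelWaveGAN \
    --pwgan_file /models/pwgan/checkpoint.pkl \
    --pwgan_config /models/pwgan/config.yml \
    --port 5002
```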


@@ -1,17 +1,20 @@
 import io
-import os
+import re
+import sys
 import numpy as np
 import torch
-import sys
+import yaml
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.generic_utils import load_config, setup_model
-from TTS.utils.text import phonemes, symbols
 from TTS.utils.speakers import load_speaker_mapping
+# pylint: disable=unused-wildcard-import
+# pylint: disable=wildcard-import
 from TTS.utils.synthesis import *
-import re
+from TTS.utils.text import make_symbols, phonemes, symbols
 alphabets = r"([A-Za-z])"
 prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
 suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
@@ -23,6 +26,7 @@ websites = r"[.](com|net|org|io|gov)"
 class Synthesizer(object):
     def __init__(self, config):
         self.wavernn = None
+        self.pwgan = None
         self.config = config
         self.use_cuda = self.config.use_cuda
         if self.use_cuda:
@@ -30,24 +34,34 @@ class Synthesizer(object):
         self.load_tts(self.config.tts_checkpoint, self.config.tts_config,
                       self.config.use_cuda)
         if self.config.wavernn_lib_path:
-            self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_path,
-                              self.config.wavernn_file, self.config.wavernn_config,
-                              self.config.use_cuda)
+            self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_file,
+                              self.config.wavernn_config, self.config.use_cuda)
+        if self.config.pwgan_lib_path:
+            self.load_pwgan(self.config.pwgan_lib_path, self.config.pwgan_file,
+                            self.config.pwgan_config, self.config.use_cuda)
     def load_tts(self, tts_checkpoint, tts_config, use_cuda):
+        # pylint: disable=global-statement
+        global symbols, phonemes
         print(" > Loading TTS model ...")
         print(" | > model config: ", tts_config)
         print(" | > checkpoint file: ", tts_checkpoint)
         self.tts_config = load_config(tts_config)
         self.use_phonemes = self.tts_config.use_phonemes
         self.ap = AudioProcessor(**self.tts_config.audio)
+        if 'characters' in self.tts_config.keys():
+            symbols, phonemes = make_symbols(**self.tts_config.characters)
         if self.use_phonemes:
             self.input_size = len(phonemes)
         else:
             self.input_size = len(symbols)
-        # load speakers
+        # TODO: fix this for multi-speaker model - load speakers
         if self.config.tts_speakers is not None:
-            self.tts_speakers = load_speaker_mapping(os.path.join(model_path, self.config.tts_speakers))
+            self.tts_speakers = load_speaker_mapping(self.config.tts_speakers)
             num_speakers = len(self.tts_speakers)
         else:
             num_speakers = 0
@@ -63,16 +77,17 @@ class Synthesizer(object):
         if 'r' in cp:
             self.tts_model.decoder.set_r(cp['r'])
-    def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda):
+    def load_wavernn(self, lib_path, model_file, model_config, use_cuda):
         # TODO: set a function in wavernn code base for model setup and call it here.
-        sys.path.append(lib_path) # set this if TTS is not installed globally
+        sys.path.append(lib_path) # set this if WaveRNN is not installed globally
+        # pylint: disable=import-outside-toplevel
         from WaveRNN.models.wavernn import Model
-        wavernn_config = os.path.join(model_path, model_config)
-        model_file = os.path.join(model_path, model_file)
         print(" > Loading WaveRNN model ...")
-        print(" | > model config: ", wavernn_config)
+        print(" | > model config: ", model_config)
         print(" | > model file: ", model_file)
-        self.wavernn_config = load_config(wavernn_config)
+        self.wavernn_config = load_config(model_config)
+        # This is the default architecture we use for our models.
+        # You might need to update it
         self.wavernn = Model(
             rnn_dims=512,
             fc_dims=512,
@@ -80,7 +95,7 @@ class Synthesizer(object):
             mulaw=self.wavernn_config.mulaw,
             pad=self.wavernn_config.pad,
             use_aux_net=self.wavernn_config.use_aux_net,
-            use_upsample_net = self.wavernn_config.use_upsample_net,
+            use_upsample_net=self.wavernn_config.use_upsample_net,
             upsample_factors=self.wavernn_config.upsample_factors,
             feat_dims=80,
             compute_dims=128,
@@ -91,18 +106,35 @@ class Synthesizer(object):
         ).cuda()
-        check = torch.load(model_file)
-        self.wavernn.load_state_dict(check['model'])
+        check = torch.load(model_file, map_location="cpu")
+        self.wavernn.load_state_dict(check['model'])
         if use_cuda:
             self.wavernn.cuda()
         self.wavernn.eval()
+    def load_pwgan(self, lib_path, model_file, model_config, use_cuda):
+        sys.path.append(lib_path) # set this if ParallelWaveGAN is not installed globally
+        # pylint: disable=import-outside-toplevel
+        from parallel_wavegan.models import ParallelWaveGANGenerator
+        print(" > Loading PWGAN model ...")
+        print(" | > model config: ", model_config)
+        print(" | > model file: ", model_file)
+        with open(model_config) as f:
+            self.pwgan_config = yaml.load(f, Loader=yaml.Loader)
+        self.pwgan = ParallelWaveGANGenerator(**self.pwgan_config["generator_params"])
+        self.pwgan.load_state_dict(torch.load(model_file, map_location="cpu")["model"]["generator"])
+        self.pwgan.remove_weight_norm()
+        if use_cuda:
+            self.pwgan.cuda()
+        self.pwgan.eval()
     def save_wav(self, wav, path):
         # wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
         wav = np.array(wav)
         self.ap.save_wav(wav, path)
-    def split_into_sentences(self, text):
-        text = " " + text + " "
+    @staticmethod
+    def split_into_sentences(text):
+        text = " " + text + "  <stop>"
         text = text.replace("\n", " ")
         text = re.sub(prefixes, "\\1<prd>", text)
         text = re.sub(websites, "<prd>\\1", text)
@@ -129,15 +161,13 @@ class Synthesizer(object):
         text = text.replace("<prd>", ".")
         sentences = text.split("<stop>")
         sentences = sentences[:-1]
-        sentences = [s.strip() for s in sentences]
+        sentences = list(filter(None, [s.strip() for s in sentences]))  # remove empty sentences
         return sentences
     def tts(self, text):
         wavs = []
         sens = self.split_into_sentences(text)
         print(sens)
-        if not sens:
-            sens = [text+'.']
         for sen in sens:
             # preprocess the given text
             inputs = text_to_seqvec(sen, self.tts_config, self.use_cuda)
@@ -148,9 +178,16 @@ class Synthesizer(object):
             postnet_output, decoder_output, _ = parse_outputs(
                 postnet_output, decoder_output, alignments)
-            if self.wavernn:
-                postnet_output = postnet_output[0].data.cpu().numpy()
-                wav = self.wavernn.generate(torch.FloatTensor(postnet_output.T).unsqueeze(0).cuda(), batched=self.config.is_wavernn_batched, target=11000, overlap=550)
+            if self.pwgan:
+                vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0)
+                if self.use_cuda:
+                    vocoder_input = vocoder_input.cuda()
+                wav = self.pwgan.inference(vocoder_input, hop_size=self.ap.hop_length)
+            elif self.wavernn:
+                vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0)
+                if self.use_cuda:
+                    vocoder_input = vocoder_input.cuda()
+                wav = self.wavernn.generate(vocoder_input, batched=self.config.is_wavernn_batched, target=11000, overlap=550)
             else:
                 wav = inv_spectrogram(postnet_output, self.ap, self.tts_config)
             # trim silence
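
Since `split_into_sentences` is now a `@staticmethod` that appends its own `<stop>` marker and filters empty strings, a trailing sentence without punctuation survives and the old `if not sens` fallback becomes unnecessary. A quick illustration of the expected behavior (the full replacement-rule table is elided from this diff, so treat this as a sketch):

```python
sents = Synthesizer.split_into_sentences("Hello world. This is Mozilla TTS")
print(sents)  # expected: ['Hello world.', 'This is Mozilla TTS']
```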


@@ -61,10 +61,11 @@ package_data = ['server/templates/*']
 if 'bdist_wheel' in unknown_args and args.checkpoint and args.model_config:
     print('Embedding model in wheel file...')
     model_dir = os.path.join('server', 'model')
-    os.makedirs(model_dir, exist_ok=True)
-    embedded_checkpoint_path = os.path.join(model_dir, 'checkpoint.pth.tar')
+    tts_dir = os.path.join(model_dir, 'tts')
+    os.makedirs(tts_dir, exist_ok=True)
+    embedded_checkpoint_path = os.path.join(tts_dir, 'checkpoint.pth.tar')
     shutil.copy(args.checkpoint, embedded_checkpoint_path)
-    embedded_config_path = os.path.join(model_dir, 'config.json')
+    embedded_config_path = os.path.join(tts_dir, 'config.json')
     shutil.copy(args.model_config, embedded_config_path)
     package_data.extend([embedded_checkpoint_path, embedded_config_path])


@@ -1,3 +1,4 @@
+# pylint: disable=redefined-outer-name, unused-argument
 import os
 import time
 import argparse
@@ -7,7 +8,7 @@ import string
 from TTS.utils.synthesis import synthesis
 from TTS.utils.generic_utils import load_config, setup_model
-from TTS.utils.text.symbols import symbols, phonemes
+from TTS.utils.text.symbols import make_symbols, symbols, phonemes
 from TTS.utils.audio import AudioProcessor
@@ -25,14 +26,16 @@ def tts(model,
     t_1 = time.time()
     use_vocoder_model = vocoder_model is not None
-    waveform, alignment, _, postnet_output, stop_tokens = synthesis(
-        model, text, C, use_cuda, ap, speaker_id, False,
-        C.enable_eos_bos_chars)
+    waveform, alignment, _, postnet_output, stop_tokens = synthesis(
+        model, text, C, use_cuda, ap, speaker_id, style_wav=False,
+        truncated=False, enable_eos_bos_chars=C.enable_eos_bos_chars,
+        use_griffin_lim=(not use_vocoder_model), do_trim_silence=True)
     if C.model == "Tacotron" and use_vocoder_model:
         postnet_output = ap.out_linear_to_mel(postnet_output.T).T
     # correct if there is a scale difference b/w two models
-    postnet_output = ap._denormalize(postnet_output)
-    postnet_output = ap_vocoder._normalize(postnet_output)
     if use_vocoder_model:
+        postnet_output = ap._denormalize(postnet_output)
+        postnet_output = ap_vocoder._normalize(postnet_output)
         vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0)
         waveform = vocoder_model.generate(
             vocoder_input.cuda() if use_cuda else vocoder_input,
@@ -45,6 +48,8 @@
 if __name__ == "__main__":
+    global symbols, phonemes
     parser = argparse.ArgumentParser()
     parser.add_argument('text', type=str, help='Text to generate speech.')
     parser.add_argument('config_path',
@@ -58,7 +63,7 @@
     parser.add_argument(
         'out_path',
         type=str,
-        help='Path to save final wav file.',
+        help='Path to save final wav file. Wav file will be named as the text given.',
     )
     parser.add_argument('--use_cuda',
                         type=bool,
@@ -102,6 +107,10 @@
     # load the audio processor
     ap = AudioProcessor(**C.audio)
+    # if the vocabulary was passed, replace the default
+    if 'characters' in C.keys():
+        symbols, phonemes = make_symbols(**C.characters)
     # load speakers
     if args.speakers_json != '':
         speakers = json.load(open(args.speakers_json, 'r'))


@@ -3,9 +3,11 @@
     "tts_config":"dummy_model_config.json", // tts config.json file
     "tts_speakers": null, // json file listing speaker ids. null if no speaker embedding.
     "wavernn_lib_path": null, // Root path to the WaveRNN project folder to be imported. If this is null, model uses GL for speech synthesis.
-    "wavernn_path": null, // wavernn model root path
     "wavernn_file": null, // wavernn checkpoint file name
     "wavernn_config": null, // wavernn config file
+    "pwgan_lib_path": null,
+    "pwgan_file": null,
+    "pwgan_config": null,
     "is_wavernn_batched":true,
     "port": 5002,
     "use_cuda": false,


@@ -19,6 +19,16 @@
     "mel_fmax": 7600, // maximum freq level for mel-spec. Tune for dataset!!
     "do_trim_silence": false
 },
+  "characters":{
+      "pad": "_",
+      "eos": "~",
+      "bos": "^",
+      "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ",
+      "punctuations":"!'(),-.:;? ",
+      "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ"
+  },
   "hidden_size": 128,
   "embedding_size": 256,
   "text_cleaner": "english_cleaners",


@@ -5,13 +5,19 @@ import torch as T
 from TTS.server.synthesizer import Synthesizer
 from TTS.tests import get_tests_input_path, get_tests_output_path
-from TTS.utils.text.symbols import phonemes, symbols
+from TTS.utils.text.symbols import make_symbols, phonemes, symbols
 from TTS.utils.generic_utils import load_config, save_checkpoint, setup_model
 class DemoServerTest(unittest.TestCase):
+    # pylint: disable=R0201
     def _create_random_model(self):
+        # pylint: disable=global-statement
+        global symbols, phonemes
         config = load_config(os.path.join(get_tests_output_path(), 'dummy_model_config.json'))
+        if 'characters' in config.keys():
+            symbols, phonemes = make_symbols(**config.characters)
+
         num_chars = len(phonemes) if config.use_phonemes else len(symbols)
         model = setup_model(num_chars, 0, config)
         output_path = os.path.join(get_tests_output_path())


@@ -38,6 +38,7 @@ class TestTTSDataset(unittest.TestCase):
             c.text_cleaner,
             ap=self.ap,
             meta_data=items,
+            tp=c.characters if 'characters' in c.keys() else None,
             batch_group_size=bgs,
             min_seq_len=c.min_seq_len,
             max_seq_len=float("inf"),
@@ -137,9 +138,7 @@
             # NOTE: Below needs to check == 0 but due to an unknown reason
             # there is a slight difference between two matrices.
             # TODO: Check this assert cond more in detail.
-            assert abs((abs(mel.T)
-                        - abs(mel_dl)
-                        ).sum()) < 1e-5, (abs(mel.T) - abs(mel_dl)).sum()
+            assert abs(mel.T - mel_dl).max() < 1e-5, abs(mel.T - mel_dl).max()
             # check mel-spec correctness
             mel_spec = mel_input[0].cpu().numpy()


@@ -11,7 +11,7 @@ source /tmp/venv/bin/activate
 pip install --quiet --upgrade pip setuptools wheel
 rm -f dist/*.whl
-python setup.py bdist_wheel --checkpoint tests/outputs/checkpoint_10.pth.tar --model_config tests/outputs/dummy_model_config.json
+python setup.py --quiet bdist_wheel --checkpoint tests/outputs/checkpoint_10.pth.tar --model_config tests/outputs/dummy_model_config.json
 pip install --quiet dist/TTS*.whl
 python -m TTS.server.server &


@@ -1,7 +1,14 @@
+import os
+# pylint: disable=unused-wildcard-import
+# pylint: disable=wildcard-import
+# pylint: disable=unused-import
 import unittest
+import torch as T
 from TTS.utils.text import *
+from TTS.tests import get_tests_path
+from TTS.utils.generic_utils import load_config
+TESTS_PATH = get_tests_path()
+conf = load_config(os.path.join(TESTS_PATH, 'test_config.json'))
 def test_phoneme_to_sequence():
     text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"
@@ -9,67 +16,80 @@ def test_phoneme_to_sequence():
     lang = "en-us"
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
+    sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
+    text_hat_with_params = sequence_to_phoneme(sequence_with_params, tp=conf.characters)
     gt = "ɹiːsənt ɹɪsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
-    assert text_hat == gt
+    assert text_hat == text_hat_with_params == gt
     # multiple punctuations
     text = "Be a voice, not an! echo?"
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
+    sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
+    text_hat_with_params = sequence_to_phoneme(sequence_with_params, tp=conf.characters)
     gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ?"
     print(text_hat)
     print(len(sequence))
-    assert text_hat == gt
+    assert text_hat == text_hat_with_params == gt
     # not ending with punctuation
     text = "Be a voice, not an! echo"
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
+    sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
+    text_hat_with_params = sequence_to_phoneme(sequence_with_params, tp=conf.characters)
     gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
     print(text_hat)
     print(len(sequence))
-    assert text_hat == gt
+    assert text_hat == text_hat_with_params == gt
     # original
     text = "Be a voice, not an echo!"
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
+    sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
+    text_hat_with_params = sequence_to_phoneme(sequence_with_params, tp=conf.characters)
     gt = "biː ɐ vɔɪs, nɑːt ɐn ɛkoʊ!"
     print(text_hat)
     print(len(sequence))
-    assert text_hat == gt
+    assert text_hat == text_hat_with_params == gt
     # extra space after the sentence
     text = "Be a voice, not an! echo. "
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
+    sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
+    text_hat_with_params = sequence_to_phoneme(sequence_with_params, tp=conf.characters)
     gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ."
     print(text_hat)
     print(len(sequence))
-    assert text_hat == gt
+    assert text_hat == text_hat_with_params == gt
     # extra space after the sentence, with eos/bos enabled
     text = "Be a voice, not an! echo. "
     sequence = phoneme_to_sequence(text, text_cleaner, lang, True)
     text_hat = sequence_to_phoneme(sequence)
+    sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, True, tp=conf.characters)
+    text_hat_with_params = sequence_to_phoneme(sequence_with_params, tp=conf.characters)
     gt = "^biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ.~"
     print(text_hat)
     print(len(sequence))
-    assert text_hat == gt
+    assert text_hat == text_hat_with_params == gt
     # padding char
     text = "_Be a _voice, not an! echo_"
     sequence = phoneme_to_sequence(text, text_cleaner, lang)
     text_hat = sequence_to_phoneme(sequence)
+    sequence_with_params = phoneme_to_sequence(text, text_cleaner, lang, tp=conf.characters)
+    text_hat_with_params = sequence_to_phoneme(sequence_with_params, tp=conf.characters)
     gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
     print(text_hat)
     print(len(sequence))
-    assert text_hat == gt
+    assert text_hat == text_hat_with_params == gt
 def test_text2phone():
     text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"
     gt = "ɹ|iː|s|ə|n|t| |ɹ|ɪ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|n|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ɪ|s|p|ɑː|n|s|ə|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|uː|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!"
     lang = "en-us"
-    phonemes = text2phone(text, lang)
-    assert gt == phonemes, f"\n{phonemes} \n vs \n{gt}"
+    ph = text2phone(text, lang)
+    assert gt == ph, f"\n{ph} \n vs \n{gt}"


@@ -20,12 +20,12 @@ from TTS.utils.generic_utils import (
    get_git_branch, load_config, remove_experiment_folder, save_best_model,
    save_checkpoint, adam_weight_decay, set_init_dict, copy_config_file,
    setup_model, gradual_training_scheduler, KeepAverage,
    set_weight_decay, check_config)
from TTS.utils.logger import Logger
from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
    get_speakers
from TTS.utils.synthesis import synthesis
from TTS.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.visual import plot_alignment, plot_spectrogram
from TTS.datasets.preprocess import load_meta_data
from TTS.utils.radam import RAdam

@@ -49,6 +49,7 @@ def setup_loader(ap, r, is_val=False, verbose=False):
        c.text_cleaner,
        meta_data=meta_data_eval if is_val else meta_data_train,
        ap=ap,
        tp=c.characters if 'characters' in c.keys() else None,
        batch_group_size=0 if is_val else c.batch_group_size *
        c.batch_size,
        min_seq_len=c.min_seq_len,

@@ -515,9 +516,12 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
# FIXME: move args definition/parsing inside of main?
def main(args):  # pylint: disable=redefined-outer-name
    global meta_data_train, meta_data_eval, symbols, phonemes
    # Audio processor
    ap = AudioProcessor(**c.audio)
    if 'characters' in c.keys():
        symbols, phonemes = make_symbols(**c.characters)

    # DISTRIBUTED
    if num_gpus > 1:

@@ -687,6 +691,7 @@ if __name__ == '__main__':
    # setup output paths and read configs
    c = load_config(args.config_path)
    check_config(c)
    _ = os.path.dirname(os.path.realpath(__file__))
    OUT_PATH = args.continue_path
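
Taken together, the train.py changes boil down to the following startup sequence; a condensed sketch, with an illustrative config path:

```python
# Condensed sketch of the new startup flow in train.py (illustrative path).
c = load_config("config.json")
check_config(c)                     # fail fast on a malformed config
ap = AudioProcessor(**c.audio)
if 'characters' in c.keys():        # optional custom vocabulary
    symbols, phonemes = make_symbols(**c.characters)
```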


@@ -389,3 +389,131 @@ class KeepAverage():
    def update_values(self, value_dict):
        for key, value in value_dict.items():
            self.update_value(key, value)


def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None):
    if restricted:
        assert name in c.keys(), f' [!] {name} not defined in config.json'
    if name in c.keys():
        if max_val:
            assert c[name] <= max_val, f' [!] {name} is larger than max value {max_val}'
        if min_val:
            assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}'
        if enum_list:
            assert c[name].lower() in enum_list, f' [!] {name} is not a valid value'
        if val_type:
            assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'

def check_config(c):
    _check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
    _check_argument('run_name', c, restricted=True, val_type=str)
    _check_argument('run_description', c, val_type=str)

    # AUDIO
    _check_argument('audio', c, restricted=True, val_type=dict)

    # audio processing parameters
    _check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
    _check_argument('num_freq', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
    _check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
    _check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000)
    _check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000)
    _check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
    _check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
    _check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
    _check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
    _check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)

    # vocabulary parameters
    _check_argument('characters', c, restricted=False, val_type=dict)
    _check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
    _check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
    _check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
    _check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
    _check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
    _check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)

    # normalization parameters
    _check_argument('signal_norm', c['audio'], restricted=True, val_type=bool)
    _check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool)
    _check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000)
    _check_argument('clip_norm', c['audio'], restricted=True, val_type=bool)
    _check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000)
    _check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0)
    _check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
    _check_argument('trim_db', c['audio'], restricted=True, val_type=int)

    # training parameters
    _check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
    _check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
    _check_argument('r', c, restricted=True, val_type=int, min_val=1)
    _check_argument('gradual_training', c, restricted=False, val_type=list)
    _check_argument('loss_masking', c, restricted=True, val_type=bool)
    # _check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)

    # validation parameters
    _check_argument('run_eval', c, restricted=True, val_type=bool)
    _check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
    _check_argument('test_sentences_file', c, restricted=False, val_type=str)

    # optimizer
    _check_argument('noam_schedule', c, restricted=False, val_type=bool)
    _check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0)
    _check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
    _check_argument('lr', c, restricted=True, val_type=float, min_val=0)
    _check_argument('wd', c, restricted=True, val_type=float, min_val=0)
    _check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
    _check_argument('seq_len_norm', c, restricted=True, val_type=bool)

    # tacotron prenet
    _check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1)
    _check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn'])
    _check_argument('prenet_dropout', c, restricted=True, val_type=bool)

    # attention
    _check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original'])
    _check_argument('attention_heads', c, restricted=True, val_type=int)
    _check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax'])
    _check_argument('windowing', c, restricted=True, val_type=bool)
    _check_argument('use_forward_attn', c, restricted=True, val_type=bool)
    _check_argument('forward_attn_mask', c, restricted=True, val_type=bool)
    _check_argument('transition_agent', c, restricted=True, val_type=bool)
    _check_argument('location_attn', c, restricted=True, val_type=bool)
    _check_argument('bidirectional_decoder', c, restricted=True, val_type=bool)

    # stopnet
    _check_argument('stopnet', c, restricted=True, val_type=bool)
    _check_argument('separate_stopnet', c, restricted=True, val_type=bool)

    # tensorboard
    _check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
    _check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
    _check_argument('checkpoint', c, restricted=True, val_type=bool)
    _check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)

    # dataloading
    _check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=['english_cleaners', 'phoneme_cleaners', 'transliteration_cleaners', 'basic_cleaners'])
    _check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
    _check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
    _check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
    _check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
    _check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
    _check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)

    # paths
    _check_argument('output_path', c, restricted=True, val_type=str)

    # multi-speaker gst
    _check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
    _check_argument('style_wav_for_test', c, restricted=True, val_type=str)
    _check_argument('use_gst', c, restricted=True, val_type=bool)

    # datasets - check each entry
    _check_argument('datasets', c, restricted=True, val_type=list)
    for dataset_entry in c['datasets']:
        _check_argument('name', dataset_entry, restricted=True, val_type=str)
        _check_argument('path', dataset_entry, restricted=True, val_type=str)
        _check_argument('meta_file_train', dataset_entry, restricted=True, val_type=str)
        _check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
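
Because every check is a bare `assert`, a malformed config fails fast with a pointed message. An illustrative run against a hypothetical bad value:

```python
# Hypothetical config with an unsupported model name.
bad_config = {"model": "tacotron3"}
try:
    check_config(bad_config)
except AssertionError as err:
    print(err)  # " [!] model is not a valid value"
```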


@@ -9,10 +9,11 @@ def text_to_seqvec(text, CONFIG, use_cuda):
    if CONFIG.use_phonemes:
        seq = np.asarray(
            phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language,
                                CONFIG.enable_eos_bos_chars,
                                tp=CONFIG.characters if 'characters' in CONFIG.keys() else None),
            dtype=np.int32)
    else:
        seq = np.asarray(text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), dtype=np.int32)
    # torch tensor
    chars_var = torch.from_numpy(seq).unsqueeze(0)
    if use_cuda:

@@ -78,6 +79,7 @@ def synthesis(model,
              style_wav=None,
              truncated=False,
              enable_eos_bos_chars=False,  # pylint: disable=unused-argument
              use_griffin_lim=False,
              do_trim_silence=False):
    """Synthesize voice for the given text.

@@ -111,8 +113,10 @@ def synthesis(model,
    postnet_output, decoder_output, alignment = parse_outputs(
        postnet_output, decoder_output, alignments)
    # plot results
    wav = None
    if use_griffin_lim:
        wav = inv_spectrogram(postnet_output, ap, CONFIG)
        # trim silence
        if do_trim_silence:
            wav = trim_silence(wav, ap)
    return wav, alignment, decoder_output, postnet_output, stop_tokens
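
With `use_griffin_lim` defaulting to False, `synthesis` now returns `wav=None` unless the caller opts into Griffin-Lim inversion, so vocoder-based pipelines no longer pay for an unused spectrogram inversion. A usage sketch; the positional head of the signature is assumed from call sites, since this hunk only shows the keyword tail:

```python
# Sketch; (model, text, CONFIG, use_cuda, ap) head assumed, not shown above.
wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
    model, "Be a voice, not an echo.", CONFIG, use_cuda, ap,
    use_griffin_lim=True,   # opt in to Griffin-Lim inversion
    do_trim_silence=True)
# With use_griffin_lim=False (the default), wav is None and postnet_output
# goes to an external vocoder instead.
```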


@@ -4,15 +4,15 @@ import re
import phonemizer
from phonemizer.phonemize import phonemize
from TTS.utils.text import cleaners
from TTS.utils.text.symbols import make_symbols, symbols, phonemes, _phoneme_punctuations, _bos, \
    _eos

# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
_phonemes_to_id = {s: i for i, s in enumerate(phonemes)}
_id_to_phonemes = {i: s for i, s in enumerate(phonemes)}

# Regular expression matching text enclosed in curly braces:
_CURLY_RE = re.compile(r'(.*?)\{(.+?)\}(.*)')

@@ -38,14 +38,11 @@ def text2phone(text, language):
        if text[-1] == punctuations[-1]:
            for punct in punctuations[:-1]:
                ph = ph.replace('| |\n', '|'+punct+'| |', 1)
            ph = ph + punctuations[-1]
        else:
            for punct in punctuations:
                ph = ph.replace('| |\n', '|'+punct+'| |', 1)
    elif float(phonemizer.__version__) > 2.1:
        ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language, preserve_punctuation=True)
        # this is a simple fix for phonemizer.
        # https://github.com/bootphon/phonemizer/issues/32

@@ -59,11 +56,25 @@ def text2phone(text, language):
    return ph


def pad_with_eos_bos(phoneme_sequence, tp=None):
    # pylint: disable=global-statement
    global _phonemes_to_id, _bos, _eos
    if tp:
        _bos = tp['bos']
        _eos = tp['eos']
        _, _phonemes = make_symbols(**tp)
        _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}

    return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]]


def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None):
    # pylint: disable=global-statement
    global _phonemes_to_id
    if tp:
        _, _phonemes = make_symbols(**tp)
        _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}

    sequence = []
    text = text.replace(":", "")
    clean_text = _clean_text(text, cleaner_names)

@@ -75,21 +86,27 @@ def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False):
        sequence += _phoneme_to_sequence(phoneme)
    # Append EOS char
    if enable_eos_bos:
        sequence = pad_with_eos_bos(sequence, tp=tp)
    return sequence


def sequence_to_phoneme(sequence, tp=None):
    # pylint: disable=global-statement
    '''Converts a sequence of IDs back to a string'''
    global _id_to_phonemes
    result = ''
    if tp:
        _, _phonemes = make_symbols(**tp)
        _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}
    for symbol_id in sequence:
        if symbol_id in _id_to_phonemes:
            s = _id_to_phonemes[symbol_id]
            result += s
    return result.replace('}{', ' ')


def text_to_sequence(text, cleaner_names, tp=None):
    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded

@@ -102,6 +119,12 @@ def text_to_sequence(text, cleaner_names):
    Returns:
        List of integers corresponding to the symbols in the text
    '''
    # pylint: disable=global-statement
    global _symbol_to_id
    if tp:
        _symbols, _ = make_symbols(**tp)
        _symbol_to_id = {s: i for i, s in enumerate(_symbols)}

    sequence = []
    # Check for curly braces and treat their contents as ARPAbet:
    while text:

@@ -116,12 +139,18 @@ def text_to_sequence(text, cleaner_names):
    return sequence


def sequence_to_text(sequence, tp=None):
    '''Converts a sequence of IDs back to a string'''
    # pylint: disable=global-statement
    global _id_to_symbol
    if tp:
        _symbols, _ = make_symbols(**tp)
        _id_to_symbol = {i: s for i, s in enumerate(_symbols)}

    result = ''
    for symbol_id in sequence:
        if symbol_id in _id_to_symbol:
            s = _id_to_symbol[symbol_id]
            # Enclose ARPAbet back in curly braces:
            if len(s) > 1 and s[0] == '@':
                s = '{%s}' % s[1:]

@@ -139,11 +168,11 @@ def _clean_text(text, cleaner_names):


def _symbols_to_sequence(syms):
    return [_symbol_to_id[s] for s in syms if _should_keep_symbol(s)]


def _phoneme_to_sequence(phons):
    return [_phonemes_to_id[s] for s in list(phons) if _should_keep_phoneme(s)]


def _arpabet_to_sequence(text):

@@ -151,8 +180,8 @@ def _arpabet_to_sequence(text):
def _should_keep_symbol(s):
    return s in _symbol_to_id and s not in ['~', '^', '_']


def _should_keep_phoneme(p):
    return p in _phonemes_to_id and p not in ['~', '^', '_']
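
Since `tp` rebinds the module-level lookup tables, a custom alphabet stays in effect for subsequent calls as well. A round-trip sketch with a hypothetical lowercase-only alphabet:

```python
# Hypothetical custom alphabet; the keys match make_symbols(**tp).
custom = {
    "pad": "_", "eos": "~", "bos": "^",
    "characters": "abcdefghijklmnopqrstuvwxyz ",
    "phonemes": "abcdefghijklmnopqrstuvwxyz",
    "punctuations": "!'(),-.? ",
}
seq = text_to_sequence("hello world", ["basic_cleaners"], tp=custom)
assert sequence_to_text(seq, tp=custom) == "hello world"
```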


@@ -5,6 +5,18 @@ Defines the set of symbols used in text input to the model.
The default is a set of ASCII characters that works well for English or text that has been run
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
'''


def make_symbols(characters, phonemes, punctuations='!\'(),-.:;? ', pad='_', eos='~', bos='^'):  # pylint: disable=redefined-outer-name
    ''' Function to create symbols and phonemes '''
    _phonemes_sorted = sorted(list(phonemes))

    # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
    _arpabet = ['@' + s for s in _phonemes_sorted]

    # Export all symbols:
    _symbols = [pad, eos, bos] + list(characters) + _arpabet
    _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)

    return _symbols, _phonemes


_pad = '_'
_eos = '~'

@@ -20,14 +32,9 @@ _pulmonic_consonants = 'pbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðsz
_suprasegmentals = 'ˈˌːˑ'
_other_symbols = 'ʍwɥʜʢʡɕʑɺɧ'
_diacrilics = 'ɚ˞ɫ'
_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics

symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _eos, _bos)

# Generate ALIEN language
# from random import shuffle
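
For a concrete picture of the refactor, this is what `make_symbols` yields for a toy inventory (illustrative values only):

```python
syms, phons = make_symbols(characters="abc", phonemes="ba",
                           punctuations="!? ", pad="_", eos="~", bos="^")
# syms  == ['_', '~', '^', 'a', 'b', 'c', '@a', '@b']  (ARPAbet-prefixed phonemes)
# phons == ['_', '~', '^', 'a', 'b', '!', '?', ' ']    (sorted phonemes + punctuation)
```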


@@ -54,9 +54,10 @@ def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CON
    plt.xlabel("Decoder timestamp", fontsize=label_fontsize)
    plt.ylabel("Encoder timestamp", fontsize=label_fontsize)
    if CONFIG.use_phonemes:
        seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
        text = sequence_to_phoneme(seq, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
        print(text)
    plt.yticks(range(len(text)), list(text))

    plt.colorbar()
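
The ternary `CONFIG.characters if 'characters' in CONFIG.keys() else None` now recurs in synthesis.py and visual.py alike; a small helper could centralize it. A suggestion only, not part of this commit:

```python
def get_characters_config(CONFIG):
    """Return the optional custom `characters` block, or None if unset."""
    return CONFIG.characters if 'characters' in CONFIG.keys() else None
```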