From 4e53896438b5365269e54dae999b6ddab837b0c4 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Mon, 2 Mar 2020 15:33:13 -0300 Subject: [PATCH] fix travis lint check --- datasets/TTSDataset.py | 2 +- notebooks/Benchmark-PWGAN.ipynb | 2 +- notebooks/Benchmark.ipynb | 2 +- notebooks/TestAttention.ipynb | 2 +- server/synthesizer.py | 12 ++++--- synthesize.py | 2 +- tests/test_demo_server.py | 5 ++- tests/test_loader.py | 2 +- train.py | 1 + utils/text/__init__.py | 55 ++++++++++++++++++--------------- utils/text/symbols.py | 4 +-- utils/visual.py | 2 +- 12 files changed, 52 insertions(+), 39 deletions(-) diff --git a/datasets/TTSDataset.py b/datasets/TTSDataset.py index d649bf23..d3a6f486 100644 --- a/datasets/TTSDataset.py +++ b/datasets/TTSDataset.py @@ -195,7 +195,7 @@ class MyDataset(Dataset): mel = [self.ap.melspectrogram(w).astype('float32') for w in wav] linear = [self.ap.spectrogram(w).astype('float32') for w in wav] - mel_lengths = [m.shape[1] for m in mel] + mel_lengths = [m.shape[1] for m in mel] # compute 'stop token' targets stop_targets = [ diff --git a/notebooks/Benchmark-PWGAN.ipynb b/notebooks/Benchmark-PWGAN.ipynb index 4a2a21d7..19a1a79c 100644 --- a/notebooks/Benchmark-PWGAN.ipynb +++ b/notebooks/Benchmark-PWGAN.ipynb @@ -144,7 +144,7 @@ "\n", "# if the vocabulary was passed, replace the default\n", "if 'text' in CONFIG.keys():\n", - " symbols, phonemes = make_symbols(**CONFIG.text)\n", + " symbols, phonemes = make_symbols(**CONFIG.text)\n", "\n", "# load the model\n", "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n", diff --git a/notebooks/Benchmark.ipynb b/notebooks/Benchmark.ipynb index 528d7a3b..bf6f2774 100644 --- a/notebooks/Benchmark.ipynb +++ b/notebooks/Benchmark.ipynb @@ -151,7 +151,7 @@ "\n", "# if the vocabulary was passed, replace the default\n", "if 'text' in CONFIG.keys():\n", - " symbols, phonemes = make_symbols(**CONFIG.text)\n", + " symbols, phonemes = make_symbols(**CONFIG.text)\n", "\n", "# load the model\n", "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n", diff --git a/notebooks/TestAttention.ipynb b/notebooks/TestAttention.ipynb index 5310fb92..b0599d80 100644 --- a/notebooks/TestAttention.ipynb +++ b/notebooks/TestAttention.ipynb @@ -112,7 +112,7 @@ "\n", "# if the vocabulary was passed, replace the default\n", "if 'text' in CONFIG.keys():\n", - " symbols, phonemes = make_symbols(**CONFIG.text)\n", + " symbols, phonemes = make_symbols(**CONFIG.text)\n", "\n", "# load the model\n", "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n", diff --git a/server/synthesizer.py b/server/synthesizer.py index f001afcd..f0921513 100644 --- a/server/synthesizer.py +++ b/server/synthesizer.py @@ -9,7 +9,10 @@ import yaml from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import load_config, setup_model from TTS.utils.speakers import load_speaker_mapping +# pylint: disable=unused-wildcard-import +# pylint: disable=wildcard-import from TTS.utils.synthesis import * + from TTS.utils.text import make_symbols, phonemes, symbols alphabets = r"([A-Za-z])" @@ -38,18 +41,19 @@ class Synthesizer(object): self.config.pwgan_config, self.config.use_cuda) def load_tts(self, tts_checkpoint, tts_config, use_cuda): + # pylint: disable=global-statement global symbols, phonemes print(" > Loading TTS model ...") print(" | > model config: ", tts_config) print(" | > checkpoint file: ", tts_checkpoint) - + self.tts_config = load_config(tts_config) self.use_phonemes = self.tts_config.use_phonemes self.ap = AudioProcessor(**self.tts_config.audio) if 'text' in self.tts_config.keys(): - symbols, phonemes = make_symbols(**self.tts_config.text) + symbols, phonemes = make_symbols(**self.tts_config.text) if self.use_phonemes: self.input_size = len(phonemes) @@ -61,7 +65,7 @@ class Synthesizer(object): num_speakers = len(self.tts_speakers) else: num_speakers = 0 - self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config) + self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config) # load model state cp = torch.load(tts_checkpoint, map_location=torch.device('cpu')) # load the model @@ -91,7 +95,7 @@ class Synthesizer(object): mulaw=self.wavernn_config.mulaw, pad=self.wavernn_config.pad, use_aux_net=self.wavernn_config.use_aux_net, - use_upsample_net = self.wavernn_config.use_upsample_net, + use_upsample_net=self.wavernn_config.use_upsample_net, upsample_factors=self.wavernn_config.upsample_factors, feat_dims=80, compute_dims=128, diff --git a/synthesize.py b/synthesize.py index d294701f..6f3a235f 100644 --- a/synthesize.py +++ b/synthesize.py @@ -109,7 +109,7 @@ if __name__ == "__main__": # if the vocabulary was passed, replace the default if 'text' in C.keys(): - symbols, phonemes = make_symbols(**C.text) + symbols, phonemes = make_symbols(**C.text) # load speakers if args.speakers_json != '': diff --git a/tests/test_demo_server.py b/tests/test_demo_server.py index 3e360e20..36848942 100644 --- a/tests/test_demo_server.py +++ b/tests/test_demo_server.py @@ -10,10 +10,13 @@ from TTS.utils.generic_utils import load_config, save_checkpoint, setup_model class DemoServerTest(unittest.TestCase): + # pylint: disable=R0201 def _create_random_model(self): + # pylint: disable=global-statement + global symbols, phonemes config = load_config(os.path.join(get_tests_output_path(), 'dummy_model_config.json')) if 'text' in config.keys(): - symbols, phonemes = make_symbols(**config.text) + symbols, phonemes = make_symbols(**config.text) num_chars = len(phonemes) if config.use_phonemes else len(symbols) model = setup_model(num_chars, 0, config) diff --git a/tests/test_loader.py b/tests/test_loader.py index 5141fa85..eb23ed19 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -37,7 +37,7 @@ class TestTTSDataset(unittest.TestCase): r, c.text_cleaner, ap=self.ap, - meta_data=items, + meta_data=items, tp=c.text if 'text' in c.keys() else None, batch_group_size=bgs, min_seq_len=c.min_seq_len, diff --git a/train.py b/train.py index 616d54ac..bf5429e9 100644 --- a/train.py +++ b/train.py @@ -516,6 +516,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch): # FIXME: move args definition/parsing inside of main? def main(args): # pylint: disable=redefined-outer-name + # pylint: disable=global-variable-undefined global meta_data_train, meta_data_eval, symbols, phonemes # Audio processor ap = AudioProcessor(**c.audio) diff --git a/utils/text/__init__.py b/utils/text/__init__.py index ff21ffe0..4361bc13 100644 --- a/utils/text/__init__.py +++ b/utils/text/__init__.py @@ -8,11 +8,11 @@ from TTS.utils.text.symbols import make_symbols, symbols, phonemes, _phoneme_pun _eos # Mappings from symbol to numeric ID and vice versa: -_SYMBOL_TO_ID = {s: i for i, s in enumerate(symbols)} -_ID_TO_SYMBOL = {i: s for i, s in enumerate(symbols)} +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} -_PHONEMES_TO_ID = {s: i for i, s in enumerate(phonemes)} -_ID_TO_PHONEMES = {i: s for i, s in enumerate(phonemes)} +_phonemes_to_id = {s: i for i, s in enumerate(phonemes)} +_id_to_phonemes = {i: s for i, s in enumerate(phonemes)} # Regular expression matching text enclosed in curly braces: _CURLY_RE = re.compile(r'(.*?)\{(.+?)\}(.*)') @@ -57,21 +57,23 @@ def text2phone(text, language): def pad_with_eos_bos(phoneme_sequence, tp=None): - global _PHONEMES_TO_ID, _bos, _eos + # pylint: disable=global-statement + global _phonemes_to_id, _bos, _eos if tp: _bos = tp['bos'] _eos = tp['eos'] _, _phonemes = make_symbols(**tp) - _PHONEMES_TO_ID = {s: i for i, s in enumerate(_phonemes)} - - return [_PHONEMES_TO_ID[_bos]] + list(phoneme_sequence) + [_PHONEMES_TO_ID[_eos]] + _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)} + + return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]] def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None): - global _PHONEMES_TO_ID + # pylint: disable=global-statement + global _phonemes_to_id if tp: _, _phonemes = make_symbols(**tp) - _PHONEMES_TO_ID = {s: i for i, s in enumerate(_phonemes)} + _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)} sequence = [] text = text.replace(":", "") @@ -89,16 +91,17 @@ def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp= def sequence_to_phoneme(sequence, tp=None): + # pylint: disable=global-statement '''Converts a sequence of IDs back to a string''' - global _ID_TO_PHONEMES + global _id_to_phonemes result = '' if tp: _, _phonemes = make_symbols(**tp) - _ID_TO_PHONEMES = {i: s for i, s in enumerate(_phonemes)} - + _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)} + for symbol_id in sequence: - if symbol_id in _ID_TO_PHONEMES: - s = _ID_TO_PHONEMES[symbol_id] + if symbol_id in _id_to_phonemes: + s = _id_to_phonemes[symbol_id] result += s return result.replace('}{', ' ') @@ -116,10 +119,11 @@ def text_to_sequence(text, cleaner_names, tp=None): Returns: List of integers corresponding to the symbols in the text ''' - global _SYMBOL_TO_ID + # pylint: disable=global-statement + global _symbol_to_id if tp: _symbols, _ = make_symbols(**tp) - _SYMBOL_TO_ID = {s: i for i, s in enumerate(_symbols)} + _symbol_to_id = {s: i for i, s in enumerate(_symbols)} sequence = [] # Check for curly braces and treat their contents as ARPAbet: @@ -137,15 +141,16 @@ def text_to_sequence(text, cleaner_names, tp=None): def sequence_to_text(sequence, tp=None): '''Converts a sequence of IDs back to a string''' - global _ID_TO_SYMBOL + # pylint: disable=global-statement + global _id_to_symbol if tp: _symbols, _ = make_symbols(**tp) - _ID_TO_SYMBOL = {i: s for i, s in enumerate(_symbols)} + _id_to_symbol = {i: s for i, s in enumerate(_symbols)} result = '' for symbol_id in sequence: - if symbol_id in _ID_TO_SYMBOL: - s = _ID_TO_SYMBOL[symbol_id] + if symbol_id in _id_to_symbol: + s = _id_to_symbol[symbol_id] # Enclose ARPAbet back in curly braces: if len(s) > 1 and s[0] == '@': s = '{%s}' % s[1:] @@ -163,11 +168,11 @@ def _clean_text(text, cleaner_names): def _symbols_to_sequence(syms): - return [_SYMBOL_TO_ID[s] for s in syms if _should_keep_symbol(s)] + return [_symbol_to_id[s] for s in syms if _should_keep_symbol(s)] def _phoneme_to_sequence(phons): - return [_PHONEMES_TO_ID[s] for s in list(phons) if _should_keep_phoneme(s)] + return [_phonemes_to_id[s] for s in list(phons) if _should_keep_phoneme(s)] def _arpabet_to_sequence(text): @@ -175,8 +180,8 @@ def _arpabet_to_sequence(text): def _should_keep_symbol(s): - return s in _SYMBOL_TO_ID and s not in ['~', '^', '_'] + return s in _symbol_to_id and s not in ['~', '^', '_'] def _should_keep_phoneme(p): - return p in _PHONEMES_TO_ID and p not in ['~', '^', '_'] + return p in _phonemes_to_id and p not in ['~', '^', '_'] diff --git a/utils/text/symbols.py b/utils/text/symbols.py index db83cb29..15862cbd 100644 --- a/utils/text/symbols.py +++ b/utils/text/symbols.py @@ -16,7 +16,7 @@ def make_symbols(characters, phnms, punctuations='!\'(),-.:;? ', pad='_', eos='~ _symbols = [pad, eos, bos] + list(characters) + _arpabet _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations) - return symbols, phonemes + return _symbols, _phonemes _pad = '_' _eos = '~' @@ -34,7 +34,7 @@ _other_symbols = 'ʍwɥʜʢʡɕʑɺɧ' _diacrilics = 'ɚ˞ɫ' _phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics -symbols, phonemes = make_symbols( _characters, _phonemes,_punctuations, _pad, _eos, _bos) +symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _eos, _bos) # Generate ALIEN language # from random import shuffle diff --git a/utils/visual.py b/utils/visual.py index 2f93d812..3b24364c 100644 --- a/utils/visual.py +++ b/utils/visual.py @@ -57,7 +57,7 @@ def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CON seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, tp=CONFIG.text if 'text' in CONFIG.keys() else None) text = sequence_to_phoneme(seq, tp=CONFIG.text if 'text' in CONFIG.keys() else None) print(text) - + plt.yticks(range(len(text)), list(text)) plt.colorbar()