mirror of https://github.com/coqui-ai/TTS.git
fix travis lint check
This commit is contained in:
parent
59e2752107
commit
4e53896438
|
@ -195,7 +195,7 @@ class MyDataset(Dataset):
|
|||
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
|
||||
linear = [self.ap.spectrogram(w).astype('float32') for w in wav]
|
||||
|
||||
mel_lengths = [m.shape[1] for m in mel]
|
||||
mel_lengths = [m.shape[1] for m in mel]
|
||||
|
||||
# compute 'stop token' targets
|
||||
stop_targets = [
|
||||
|
|
|
@ -144,7 +144,7 @@
|
|||
"\n",
|
||||
"# if the vocabulary was passed, replace the default\n",
|
||||
"if 'text' in CONFIG.keys():\n",
|
||||
" symbols, phonemes = make_symbols(**CONFIG.text)\n",
|
||||
" symbols, phonemes = make_symbols(**CONFIG.text)\n",
|
||||
"\n",
|
||||
"# load the model\n",
|
||||
"num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n",
|
||||
|
|
|
@ -151,7 +151,7 @@
|
|||
"\n",
|
||||
"# if the vocabulary was passed, replace the default\n",
|
||||
"if 'text' in CONFIG.keys():\n",
|
||||
" symbols, phonemes = make_symbols(**CONFIG.text)\n",
|
||||
" symbols, phonemes = make_symbols(**CONFIG.text)\n",
|
||||
"\n",
|
||||
"# load the model\n",
|
||||
"num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n",
|
||||
|
|
|
@ -112,7 +112,7 @@
|
|||
"\n",
|
||||
"# if the vocabulary was passed, replace the default\n",
|
||||
"if 'text' in CONFIG.keys():\n",
|
||||
" symbols, phonemes = make_symbols(**CONFIG.text)\n",
|
||||
" symbols, phonemes = make_symbols(**CONFIG.text)\n",
|
||||
"\n",
|
||||
"# load the model\n",
|
||||
"num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n",
|
||||
|
|
|
@ -9,7 +9,10 @@ import yaml
|
|||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.generic_utils import load_config, setup_model
|
||||
from TTS.utils.speakers import load_speaker_mapping
|
||||
# pylint: disable=unused-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from TTS.utils.synthesis import *
|
||||
|
||||
from TTS.utils.text import make_symbols, phonemes, symbols
|
||||
|
||||
alphabets = r"([A-Za-z])"
|
||||
|
@ -38,18 +41,19 @@ class Synthesizer(object):
|
|||
self.config.pwgan_config, self.config.use_cuda)
|
||||
|
||||
def load_tts(self, tts_checkpoint, tts_config, use_cuda):
|
||||
# pylint: disable=global-statement
|
||||
global symbols, phonemes
|
||||
|
||||
print(" > Loading TTS model ...")
|
||||
print(" | > model config: ", tts_config)
|
||||
print(" | > checkpoint file: ", tts_checkpoint)
|
||||
|
||||
|
||||
self.tts_config = load_config(tts_config)
|
||||
self.use_phonemes = self.tts_config.use_phonemes
|
||||
self.ap = AudioProcessor(**self.tts_config.audio)
|
||||
|
||||
if 'text' in self.tts_config.keys():
|
||||
symbols, phonemes = make_symbols(**self.tts_config.text)
|
||||
symbols, phonemes = make_symbols(**self.tts_config.text)
|
||||
|
||||
if self.use_phonemes:
|
||||
self.input_size = len(phonemes)
|
||||
|
@ -61,7 +65,7 @@ class Synthesizer(object):
|
|||
num_speakers = len(self.tts_speakers)
|
||||
else:
|
||||
num_speakers = 0
|
||||
self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config)
|
||||
self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config)
|
||||
# load model state
|
||||
cp = torch.load(tts_checkpoint, map_location=torch.device('cpu'))
|
||||
# load the model
|
||||
|
@ -91,7 +95,7 @@ class Synthesizer(object):
|
|||
mulaw=self.wavernn_config.mulaw,
|
||||
pad=self.wavernn_config.pad,
|
||||
use_aux_net=self.wavernn_config.use_aux_net,
|
||||
use_upsample_net = self.wavernn_config.use_upsample_net,
|
||||
use_upsample_net=self.wavernn_config.use_upsample_net,
|
||||
upsample_factors=self.wavernn_config.upsample_factors,
|
||||
feat_dims=80,
|
||||
compute_dims=128,
|
||||
|
|
|
@ -109,7 +109,7 @@ if __name__ == "__main__":
|
|||
|
||||
# if the vocabulary was passed, replace the default
|
||||
if 'text' in C.keys():
|
||||
symbols, phonemes = make_symbols(**C.text)
|
||||
symbols, phonemes = make_symbols(**C.text)
|
||||
|
||||
# load speakers
|
||||
if args.speakers_json != '':
|
||||
|
|
|
@ -10,10 +10,13 @@ from TTS.utils.generic_utils import load_config, save_checkpoint, setup_model
|
|||
|
||||
|
||||
class DemoServerTest(unittest.TestCase):
|
||||
# pylint: disable=R0201
|
||||
def _create_random_model(self):
|
||||
# pylint: disable=global-statement
|
||||
global symbols, phonemes
|
||||
config = load_config(os.path.join(get_tests_output_path(), 'dummy_model_config.json'))
|
||||
if 'text' in config.keys():
|
||||
symbols, phonemes = make_symbols(**config.text)
|
||||
symbols, phonemes = make_symbols(**config.text)
|
||||
|
||||
num_chars = len(phonemes) if config.use_phonemes else len(symbols)
|
||||
model = setup_model(num_chars, 0, config)
|
||||
|
|
|
@ -37,7 +37,7 @@ class TestTTSDataset(unittest.TestCase):
|
|||
r,
|
||||
c.text_cleaner,
|
||||
ap=self.ap,
|
||||
meta_data=items,
|
||||
meta_data=items,
|
||||
tp=c.text if 'text' in c.keys() else None,
|
||||
batch_group_size=bgs,
|
||||
min_seq_len=c.min_seq_len,
|
||||
|
|
1
train.py
1
train.py
|
@ -516,6 +516,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
|||
|
||||
# FIXME: move args definition/parsing inside of main?
|
||||
def main(args): # pylint: disable=redefined-outer-name
|
||||
# pylint: disable=global-variable-undefined
|
||||
global meta_data_train, meta_data_eval, symbols, phonemes
|
||||
# Audio processor
|
||||
ap = AudioProcessor(**c.audio)
|
||||
|
|
|
@ -8,11 +8,11 @@ from TTS.utils.text.symbols import make_symbols, symbols, phonemes, _phoneme_pun
|
|||
_eos
|
||||
|
||||
# Mappings from symbol to numeric ID and vice versa:
|
||||
_SYMBOL_TO_ID = {s: i for i, s in enumerate(symbols)}
|
||||
_ID_TO_SYMBOL = {i: s for i, s in enumerate(symbols)}
|
||||
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
|
||||
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
|
||||
|
||||
_PHONEMES_TO_ID = {s: i for i, s in enumerate(phonemes)}
|
||||
_ID_TO_PHONEMES = {i: s for i, s in enumerate(phonemes)}
|
||||
_phonemes_to_id = {s: i for i, s in enumerate(phonemes)}
|
||||
_id_to_phonemes = {i: s for i, s in enumerate(phonemes)}
|
||||
|
||||
# Regular expression matching text enclosed in curly braces:
|
||||
_CURLY_RE = re.compile(r'(.*?)\{(.+?)\}(.*)')
|
||||
|
@ -57,21 +57,23 @@ def text2phone(text, language):
|
|||
|
||||
|
||||
def pad_with_eos_bos(phoneme_sequence, tp=None):
|
||||
global _PHONEMES_TO_ID, _bos, _eos
|
||||
# pylint: disable=global-statement
|
||||
global _phonemes_to_id, _bos, _eos
|
||||
if tp:
|
||||
_bos = tp['bos']
|
||||
_eos = tp['eos']
|
||||
_, _phonemes = make_symbols(**tp)
|
||||
_PHONEMES_TO_ID = {s: i for i, s in enumerate(_phonemes)}
|
||||
|
||||
return [_PHONEMES_TO_ID[_bos]] + list(phoneme_sequence) + [_PHONEMES_TO_ID[_eos]]
|
||||
_phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
|
||||
|
||||
return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]]
|
||||
|
||||
|
||||
def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None):
|
||||
global _PHONEMES_TO_ID
|
||||
# pylint: disable=global-statement
|
||||
global _phonemes_to_id
|
||||
if tp:
|
||||
_, _phonemes = make_symbols(**tp)
|
||||
_PHONEMES_TO_ID = {s: i for i, s in enumerate(_phonemes)}
|
||||
_phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
|
||||
|
||||
sequence = []
|
||||
text = text.replace(":", "")
|
||||
|
@ -89,16 +91,17 @@ def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=
|
|||
|
||||
|
||||
def sequence_to_phoneme(sequence, tp=None):
|
||||
# pylint: disable=global-statement
|
||||
'''Converts a sequence of IDs back to a string'''
|
||||
global _ID_TO_PHONEMES
|
||||
global _id_to_phonemes
|
||||
result = ''
|
||||
if tp:
|
||||
_, _phonemes = make_symbols(**tp)
|
||||
_ID_TO_PHONEMES = {i: s for i, s in enumerate(_phonemes)}
|
||||
|
||||
_id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}
|
||||
|
||||
for symbol_id in sequence:
|
||||
if symbol_id in _ID_TO_PHONEMES:
|
||||
s = _ID_TO_PHONEMES[symbol_id]
|
||||
if symbol_id in _id_to_phonemes:
|
||||
s = _id_to_phonemes[symbol_id]
|
||||
result += s
|
||||
return result.replace('}{', ' ')
|
||||
|
||||
|
@ -116,10 +119,11 @@ def text_to_sequence(text, cleaner_names, tp=None):
|
|||
Returns:
|
||||
List of integers corresponding to the symbols in the text
|
||||
'''
|
||||
global _SYMBOL_TO_ID
|
||||
# pylint: disable=global-statement
|
||||
global _symbol_to_id
|
||||
if tp:
|
||||
_symbols, _ = make_symbols(**tp)
|
||||
_SYMBOL_TO_ID = {s: i for i, s in enumerate(_symbols)}
|
||||
_symbol_to_id = {s: i for i, s in enumerate(_symbols)}
|
||||
|
||||
sequence = []
|
||||
# Check for curly braces and treat their contents as ARPAbet:
|
||||
|
@ -137,15 +141,16 @@ def text_to_sequence(text, cleaner_names, tp=None):
|
|||
|
||||
def sequence_to_text(sequence, tp=None):
|
||||
'''Converts a sequence of IDs back to a string'''
|
||||
global _ID_TO_SYMBOL
|
||||
# pylint: disable=global-statement
|
||||
global _id_to_symbol
|
||||
if tp:
|
||||
_symbols, _ = make_symbols(**tp)
|
||||
_ID_TO_SYMBOL = {i: s for i, s in enumerate(_symbols)}
|
||||
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
|
||||
|
||||
result = ''
|
||||
for symbol_id in sequence:
|
||||
if symbol_id in _ID_TO_SYMBOL:
|
||||
s = _ID_TO_SYMBOL[symbol_id]
|
||||
if symbol_id in _id_to_symbol:
|
||||
s = _id_to_symbol[symbol_id]
|
||||
# Enclose ARPAbet back in curly braces:
|
||||
if len(s) > 1 and s[0] == '@':
|
||||
s = '{%s}' % s[1:]
|
||||
|
@ -163,11 +168,11 @@ def _clean_text(text, cleaner_names):
|
|||
|
||||
|
||||
def _symbols_to_sequence(syms):
|
||||
return [_SYMBOL_TO_ID[s] for s in syms if _should_keep_symbol(s)]
|
||||
return [_symbol_to_id[s] for s in syms if _should_keep_symbol(s)]
|
||||
|
||||
|
||||
def _phoneme_to_sequence(phons):
|
||||
return [_PHONEMES_TO_ID[s] for s in list(phons) if _should_keep_phoneme(s)]
|
||||
return [_phonemes_to_id[s] for s in list(phons) if _should_keep_phoneme(s)]
|
||||
|
||||
|
||||
def _arpabet_to_sequence(text):
|
||||
|
@ -175,8 +180,8 @@ def _arpabet_to_sequence(text):
|
|||
|
||||
|
||||
def _should_keep_symbol(s):
|
||||
return s in _SYMBOL_TO_ID and s not in ['~', '^', '_']
|
||||
return s in _symbol_to_id and s not in ['~', '^', '_']
|
||||
|
||||
|
||||
def _should_keep_phoneme(p):
|
||||
return p in _PHONEMES_TO_ID and p not in ['~', '^', '_']
|
||||
return p in _phonemes_to_id and p not in ['~', '^', '_']
|
||||
|
|
|
@ -16,7 +16,7 @@ def make_symbols(characters, phnms, punctuations='!\'(),-.:;? ', pad='_', eos='~
|
|||
_symbols = [pad, eos, bos] + list(characters) + _arpabet
|
||||
_phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)
|
||||
|
||||
return symbols, phonemes
|
||||
return _symbols, _phonemes
|
||||
|
||||
_pad = '_'
|
||||
_eos = '~'
|
||||
|
@ -34,7 +34,7 @@ _other_symbols = 'ʍwɥʜʢʡɕʑɺɧ'
|
|||
_diacrilics = 'ɚ˞ɫ'
|
||||
_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics
|
||||
|
||||
symbols, phonemes = make_symbols( _characters, _phonemes,_punctuations, _pad, _eos, _bos)
|
||||
symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _eos, _bos)
|
||||
|
||||
# Generate ALIEN language
|
||||
# from random import shuffle
|
||||
|
|
|
@ -57,7 +57,7 @@ def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CON
|
|||
seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, tp=CONFIG.text if 'text' in CONFIG.keys() else None)
|
||||
text = sequence_to_phoneme(seq, tp=CONFIG.text if 'text' in CONFIG.keys() else None)
|
||||
print(text)
|
||||
|
||||
|
||||
plt.yticks(range(len(text)), list(text))
|
||||
plt.colorbar()
|
||||
|
||||
|
|
Loading…
Reference in New Issue