mirror of https://github.com/coqui-ai/TTS.git
Bug fixes
This commit is contained in:
parent
af22bed149
commit
7e020d4084
|
@ -170,9 +170,9 @@ class AttentionRNNCell(nn.Module):
|
||||||
# Update the window
|
# Update the window
|
||||||
self.win_idx = torch.argmax(alignment,1).long()[0].item()
|
self.win_idx = torch.argmax(alignment,1).long()[0].item()
|
||||||
# Normalize context weight
|
# Normalize context weight
|
||||||
# alignment = F.softmax(alignment, dim=-1)
|
alignment = F.softmax(alignment, dim=-1)
|
||||||
# alignment = 5 * alignment
|
# alignment = 5 * alignment
|
||||||
alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
|
# alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
|
||||||
# Attention context vector
|
# Attention context vector
|
||||||
# (batch, 1, dim)
|
# (batch, 1, dim)
|
||||||
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
|
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
|
||||||
|
|
4
train.py
4
train.py
|
@ -223,10 +223,6 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
|
||||||
avg_stop_loss, epoch_time, avg_step_time),
|
avg_stop_loss, epoch_time, avg_step_time),
|
||||||
flush=True)
|
flush=True)
|
||||||
|
|
||||||
align_img = alignments[0].data.cpu().numpy()
|
|
||||||
align_img = plot_alignment(align_img)
|
|
||||||
align_img.savefig('/home/erogol/Desktop/alignment_{}.png'.format(current_step))
|
|
||||||
|
|
||||||
# Plot Training Epoch Stats
|
# Plot Training Epoch Stats
|
||||||
epoch_stats = {"loss_postnet": avg_linear_loss,
|
epoch_stats = {"loss_postnet": avg_linear_loss,
|
||||||
"loss_decoder": avg_mel_loss,
|
"loss_decoder": avg_mel_loss,
|
||||||
|
|
|
@ -20,54 +20,6 @@ _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
|
||||||
pat = r'['+_punctuations[:-1]+']+'
|
pat = r'['+_punctuations[:-1]+']+'
|
||||||
|
|
||||||
|
|
||||||
def text2phone(text):
|
|
||||||
'''
|
|
||||||
Convert graphemes to phonemes.
|
|
||||||
'''
|
|
||||||
seperator = phonemizer.separator.Separator(' ', '', '|')
|
|
||||||
#try:
|
|
||||||
punctuations = re.findall(pat, text)
|
|
||||||
ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language='en-us')
|
|
||||||
# Replace \n with matching punctuations.
|
|
||||||
if len(punctuations) > 0:
|
|
||||||
for punct in punctuations[:-1]:
|
|
||||||
ph = ph.replace(' \n', punct+'| ', 1)
|
|
||||||
try:
|
|
||||||
ph = ph[:-1] + punctuations[-1]
|
|
||||||
except:
|
|
||||||
print(text)
|
|
||||||
return ph
|
|
||||||
|
|
||||||
|
|
||||||
def phoneme_to_sequence(text, cleaner_names):
|
|
||||||
'''
|
|
||||||
TODO: This ignores punctuations
|
|
||||||
'''
|
|
||||||
sequence = []
|
|
||||||
clean_text = _clean_text(text, cleaner_names)
|
|
||||||
phonemes = text2phone(clean_text)
|
|
||||||
# print(phonemes.replace('|', ''))
|
|
||||||
if phonemes is None:
|
|
||||||
print("!! After phoneme conversion the result is None. -- {} ".format(clean_text))
|
|
||||||
for phoneme in phonemes.split('|'):
|
|
||||||
# print(word, ' -- ', phonemes_text)
|
|
||||||
sequence += _phoneme_to_sequence(phoneme)
|
|
||||||
# Aeepnd EOS char
|
|
||||||
sequence.append(_phonemes_to_id['~'])
|
|
||||||
return sequence
|
|
||||||
|
|
||||||
|
|
||||||
def sequence_to_phoneme(sequence):
|
|
||||||
'''Converts a sequence of IDs back to a string'''
|
|
||||||
result = ''
|
|
||||||
for symbol_id in sequence:
|
|
||||||
if symbol_id in _id_to_phonemes:
|
|
||||||
s = _id_to_phonemes[symbol_id]
|
|
||||||
print(s)
|
|
||||||
result += s
|
|
||||||
return result.replace('}{', ' ')
|
|
||||||
|
|
||||||
|
|
||||||
def text2phone(text, language):
|
def text2phone(text, language):
|
||||||
'''
|
'''
|
||||||
Convert graphemes to phonemes.
|
Convert graphemes to phonemes.
|
||||||
|
|
Loading…
Reference in New Issue