diff --git a/layers/attention.py b/layers/attention.py index fc50a9dd..054f792c 100644 --- a/layers/attention.py +++ b/layers/attention.py @@ -170,9 +170,9 @@ class AttentionRNNCell(nn.Module): # Update the window self.win_idx = torch.argmax(alignment,1).long()[0].item() # Normalize context weight - # alignment = F.softmax(alignment, dim=-1) + alignment = F.softmax(alignment, dim=-1) # alignment = 5 * alignment - alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1) + # alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1) # Attention context vector # (batch, 1, dim) # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j diff --git a/train.py b/train.py index cd969e9d..3e8bf9ed 100644 --- a/train.py +++ b/train.py @@ -223,10 +223,6 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, avg_stop_loss, epoch_time, avg_step_time), flush=True) - align_img = alignments[0].data.cpu().numpy() - align_img = plot_alignment(align_img) - align_img.savefig('/home/erogol/Desktop/alignment_{}.png'.format(current_step)) - # Plot Training Epoch Stats epoch_stats = {"loss_postnet": avg_linear_loss, "loss_decoder": avg_mel_loss, diff --git a/utils/text/__init__.py b/utils/text/__init__.py index 3008418b..d29ae71e 100644 --- a/utils/text/__init__.py +++ b/utils/text/__init__.py @@ -20,54 +20,6 @@ _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') pat = r'['+_punctuations[:-1]+']+' -def text2phone(text): - ''' - Convert graphemes to phonemes. - ''' - seperator = phonemizer.separator.Separator(' ', '', '|') - #try: - punctuations = re.findall(pat, text) - ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language='en-us') - # Replace \n with matching punctuations. - if len(punctuations) > 0: - for punct in punctuations[:-1]: - ph = ph.replace(' \n', punct+'| ', 1) - try: - ph = ph[:-1] + punctuations[-1] - except: - print(text) - return ph - - -def phoneme_to_sequence(text, cleaner_names): - ''' - TODO: This ignores punctuations - ''' - sequence = [] - clean_text = _clean_text(text, cleaner_names) - phonemes = text2phone(clean_text) -# print(phonemes.replace('|', '')) - if phonemes is None: - print("!! After phoneme conversion the result is None. -- {} ".format(clean_text)) - for phoneme in phonemes.split('|'): - # print(word, ' -- ', phonemes_text) - sequence += _phoneme_to_sequence(phoneme) - # Aeepnd EOS char - sequence.append(_phonemes_to_id['~']) - return sequence - - -def sequence_to_phoneme(sequence): - '''Converts a sequence of IDs back to a string''' - result = '' - for symbol_id in sequence: - if symbol_id in _id_to_phonemes: - s = _id_to_phonemes[symbol_id] - print(s) - result += s - return result.replace('}{', ' ') - - def text2phone(text, language): ''' Convert graphemes to phonemes.