From 5629292bde3fd5bf7c81d242274bc2c8aa07c6e5 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Fri, 16 Aug 2019 15:08:04 +0200 Subject: [PATCH] bug fixes --- layers/tacotron.py | 8 ++++---- tests/symbols_tests.py | 4 +++- tests/test_loader.py | 8 +++++--- utils/text/symbols.py | 2 +- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/layers/tacotron.py b/layers/tacotron.py index 09c7e923..b5b6e132 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -364,13 +364,13 @@ class Decoder(nn.Module): processed_memory = self.prenet(self.memory_input) # Attention RNN self.attention_rnn_hidden = self.attention_rnn( - torch.cat((processed_memory, self.current_context_vec), -1), + torch.cat((processed_memory, self.context_vec), -1), self.attention_rnn_hidden) - self.context_vec = self.attention_layer( + self.context_vec = self.attention( self.attention_rnn_hidden, inputs, self.processed_inputs, mask) # Concat RNN output and attention context vector decoder_input = self.project_to_decoder_in( - torch.cat((self.query, self.context_vec), -1)) + torch.cat((self.attention_rnn_hidden, self.context_vec), -1)) # Pass through the decoder RNNs for idx in range(len(self.decoder_rnns)): @@ -390,7 +390,7 @@ class Decoder(nn.Module): else: stop_token = self.stopnet(stopnet_input) output = output[:, : self.r * self.memory_dim] - return output, stop_token, self.attention_layer.attention_weights + return output, stop_token, self.attention.attention_weights def _update_memory_input(self, new_memory): if self.use_memory_queue: diff --git a/tests/symbols_tests.py b/tests/symbols_tests.py index 68c909c5..9bec0f18 100644 --- a/tests/symbols_tests.py +++ b/tests/symbols_tests.py @@ -1,7 +1,9 @@ import unittest from utils.text import phonemes +from collections import Counter class SymbolsTest(unittest.TestCase): def test_uniqueness(self): - assert sorted(phonemes) == sorted(list(set(phonemes))) + assert sorted(phonemes) == sorted(list(set(phonemes))), " {} vs {} ".format(len(phonemes), len(set(phonemes))) + \ No newline at end of file diff --git a/tests/test_loader.py b/tests/test_loader.py index 92d6f7e2..4051c463 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -1,6 +1,7 @@ import os import unittest import shutil +import torch from torch.utils.data import DataLoader from utils.generic_utils import load_config @@ -130,10 +131,11 @@ class TestTTSDataset(unittest.TestCase): # check mel_spec consistency wav = self.ap.load_wav(item_idx[0]) mel = self.ap.melspectrogram(wav) - mel_dl = mel_input[0].cpu().numpy() - assert (abs(mel.T).astype("float32") + mel = torch.FloatTensor(mel) + mel_dl = mel_input[0] + assert (abs(mel.T) - abs(mel_dl[:-1]) - ).sum() == 0 + ).sum() == 0, (abs(mel.T)- abs(mel_dl[:-1])).sum() # check mel-spec correctness mel_spec = mel_input[0].cpu().numpy() diff --git a/utils/text/symbols.py b/utils/text/symbols.py index 9b7a36b4..ee6fd2cf 100644 --- a/utils/text/symbols.py +++ b/utils/text/symbols.py @@ -18,7 +18,7 @@ _vowels = 'iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ' _non_pulmonic_consonants = 'ʘɓǀɗǃʄǂɠǁʛ' _pulmonic_consonants = 'pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ' _suprasegmentals = 'ˈˌːˑ' -_other_symbols = 'ʍwɥʜʢʡɕʑɺɧ ' +_other_symbols = 'ʍwɥʜʢʡɕʑɺɧ' _diacrilics = 'ɚ˞ɫ' _phonemes = sorted(list(_vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics))