diff --git a/models/tacotron2.py b/models/tacotron2.py index 23a40d4f..75ae9bef 100644 --- a/models/tacotron2.py +++ b/models/tacotron2.py @@ -93,14 +93,14 @@ class Tacotron2(TacotronAbstract): if self.num_speakers > 1: embedded_speakers = self.speaker_embedding(speaker_ids)[:, None] embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1) - if hasattr(self, 'gst'): + if self.gst: # B x gst_dim encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs) encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1) else: encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1) else: - if hasattr(self, 'gst'): + if self.gst: # B x gst_dim encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs) encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1) @@ -138,14 +138,14 @@ class Tacotron2(TacotronAbstract): if self.num_speakers > 1: embedded_speakers = self.speaker_embedding(speaker_ids)[:, None] embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1) - if hasattr(self, 'gst'): + if self.gst: # B x gst_dim encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel) encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1) else: encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1) else: - if hasattr(self, 'gst'): + if self.gst: # B x gst_dim encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel) encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1) @@ -168,14 +168,14 @@ class Tacotron2(TacotronAbstract): if self.num_speakers > 1: embedded_speakers = self.speaker_embedding(speaker_ids)[:, None] embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1) - if hasattr(self, 'gst'): + if self.gst: # B x gst_dim encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel) encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1) else: encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1) else: - if hasattr(self, 'gst'): + if self.gst: # B x gst_dim encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel) encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1) diff --git a/tests/outputs/dummy_model_config.json b/tests/outputs/dummy_model_config.json index 36fac3e5..c35c7495 100644 --- a/tests/outputs/dummy_model_config.json +++ b/tests/outputs/dummy_model_config.json @@ -83,6 +83,14 @@ "use_phonemes": false, // use phonemes instead of raw characters. It is suggested for better pronounciation. "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages "text_cleaner": "phoneme_cleaners", - "use_speaker_embedding": false // whether to use additional embeddings for separate speakers + "use_speaker_embedding": false, // whether to use additional embeddings for separate speakers + "use_gst": false, + "gst": { + "gst_style_input": null, + "gst_embedding_dim": 256, + "gst_num_heads": 4, + "gst_style_tokens": 10 + } } + diff --git a/tests/test_config.json b/tests/test_config.json index 450cb23a..31c2cd87 100644 --- a/tests/test_config.json +++ b/tests/test_config.json @@ -51,14 +51,5 @@ "output_path": "result", "min_seq_len": 0, "max_seq_len": 300, - "log_dir": "tests/outputs/", - - "use_speaker_embedding": false, - "use_gst": false, - "gst": { - "gst_style_input": null, - "gst_embedding_dim": 512, - "gst_num_heads": 4, - "gst_style_tokens": 10 - } + "log_dir": "tests/outputs/" } diff --git a/tests/test_tacotron2_model.py b/tests/test_tacotron2_model.py index ae9f20a2..5dfd7759 100644 --- a/tests/test_tacotron2_model.py +++ b/tests/test_tacotron2_model.py @@ -22,7 +22,7 @@ c = load_config(os.path.join(file_path, 'test_config.json')) class TacotronTrainTest(unittest.TestCase): def test_train_step(self): - input = torch.randint(0, 24, (8, 128)).long().to(device) + input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) input_lengths = torch.randint(100, 128, (8, )).long().to(device) input_lengths = torch.sort(input_lengths, descending=True)[0] mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) @@ -35,7 +35,7 @@ class TacotronTrainTest(unittest.TestCase): for idx in mel_lengths: stop_targets[:, int(idx.item()):, 0] = 1.0 - stop_targets = stop_targets.view(input.shape[0], + stop_targets = stop_targets.view(input_dummy.shape[0], stop_targets.size(1) // c.r, -1) stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() @@ -52,7 +52,63 @@ class TacotronTrainTest(unittest.TestCase): optimizer = optim.Adam(model.parameters(), lr=c.lr) for i in range(5): mel_out, mel_postnet_out, align, stop_tokens = model.forward( - input, input_lengths, mel_spec, mel_lengths, speaker_ids) + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids) + assert torch.sigmoid(stop_tokens).data.max() <= 1.0 + assert torch.sigmoid(stop_tokens).data.min() >= 0.0 + optimizer.zero_grad() + loss = criterion(mel_out, mel_spec, mel_lengths) + stop_loss = criterion_st(stop_tokens, stop_targets) + loss = loss + criterion(mel_postnet_out, mel_postnet_spec, mel_lengths) + stop_loss + loss.backward() + optimizer.step() + # check parameter changes + count = 0 + for param, param_ref in zip(model.parameters(), + model_ref.parameters()): + # ignore pre-higway layer since it works conditional + # if count not in [145, 59]: + assert (param != param_ref).any( + ), "param {} with shape {} not updated!! \n{}\n{}".format( + count, param.shape, param, param_ref) + count += 1 + + +class TacotronGSTTrainTest(unittest.TestCase): + def test_train_step(self): + input_dummy = torch.randint(0, 24, (8, 128)).long().to(device) + input_lengths = torch.randint(100, 128, (8, )).long().to(device) + input_lengths = torch.sort(input_lengths, descending=True)[0] + mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) + mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) + mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + mel_lengths[0] = 30 + stop_targets = torch.zeros(8, 30, 1).float().to(device) + speaker_ids = torch.randint(0, 5, (8, )).long().to(device) + + for idx in mel_lengths: + stop_targets[:, int(idx.item()):, 0] = 1.0 + + stop_targets = stop_targets.view(input_dummy.shape[0], + stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() + + criterion = MSELossMasked(seq_len_norm=False).to(device) + criterion_st = nn.BCEWithLogitsLoss().to(device) + model = Tacotron2(num_chars=24, + gst=True, + r=c.r, + num_speakers=5).to(device) + model.train() + model_ref = copy.deepcopy(model) + count = 0 + for param, param_ref in zip(model.parameters(), + model_ref.parameters()): + assert (param - param_ref).sum() == 0, param + count += 1 + optimizer = optim.Adam(model.parameters(), lr=c.lr) + for i in range(5): + mel_out, mel_postnet_out, align, stop_tokens = model.forward( + input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids) assert torch.sigmoid(stop_tokens).data.max() <= 1.0 assert torch.sigmoid(stop_tokens).data.min() >= 0.0 optimizer.zero_grad() diff --git a/tests/test_tacotron_model.py b/tests/test_tacotron_model.py index 2bbb3c8d..00cc38df 100644 --- a/tests/test_tacotron_model.py +++ b/tests/test_tacotron_model.py @@ -31,7 +31,7 @@ class TacotronTrainTest(unittest.TestCase): input_lengths = torch.randint(100, 129, (8, )).long().to(device) input_lengths[-1] = 128 mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device) + linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device) mel_lengths = torch.randint(20, 30, (8, )).long().to(device) stop_targets = torch.zeros(8, 30, 1).float().to(device) speaker_ids = torch.randint(0, 5, (8, )).long().to(device) @@ -49,7 +49,7 @@ class TacotronTrainTest(unittest.TestCase): model = Tacotron( num_chars=32, num_speakers=5, - postnet_output_dim=c.audio['num_freq'], + postnet_output_dim=c.audio['fft_size'], decoder_output_dim=c.audio['num_mels'], r=c.r, memory_size=c.memory_size @@ -93,7 +93,7 @@ class TacotronGSTTrainTest(unittest.TestCase): input_lengths = torch.randint(100, 129, (8, )).long().to(device) input_lengths[-1] = 128 mel_spec = torch.rand(8, 120, c.audio['num_mels']).to(device) - linear_spec = torch.rand(8, 120, c.audio['num_freq']).to(device) + linear_spec = torch.rand(8, 120, c.audio['fft_size']).to(device) mel_lengths = torch.randint(20, 120, (8, )).long().to(device) mel_lengths[-1] = 120 stop_targets = torch.zeros(8, 120, 1).float().to(device) @@ -113,7 +113,7 @@ class TacotronGSTTrainTest(unittest.TestCase): num_chars=32, num_speakers=5, gst=True, - postnet_output_dim=c.audio['num_freq'], + postnet_output_dim=c.audio['fft_size'], decoder_output_dim=c.audio['num_mels'], r=c.r, memory_size=c.memory_size