tacotrongst test + test fixes

This commit is contained in:
SanjaESC 2020-07-12 14:07:44 +02:00 committed by thllwg
parent 564fc0aab4
commit 6d3ddae64e
5 changed files with 79 additions and 24 deletions

View File

@ -93,14 +93,14 @@ class Tacotron2(TacotronAbstract):
if self.num_speakers > 1:
embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
if hasattr(self, 'gst'):
if self.gst:
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
else:
encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
else:
if hasattr(self, 'gst'):
if self.gst:
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, mel_specs)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
@ -138,14 +138,14 @@ class Tacotron2(TacotronAbstract):
if self.num_speakers > 1:
embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
if hasattr(self, 'gst'):
if self.gst:
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
else:
encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
else:
if hasattr(self, 'gst'):
if self.gst:
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)
@ -168,14 +168,14 @@ class Tacotron2(TacotronAbstract):
if self.num_speakers > 1:
embedded_speakers = self.speaker_embedding(speaker_ids)[:, None]
embedded_speakers = embedded_speakers.repeat(1, encoder_outputs.size(1), 1)
if hasattr(self, 'gst'):
if self.gst:
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst, embedded_speakers], dim=-1)
else:
encoder_outputs = torch.cat([encoder_outputs, embedded_speakers], dim=-1)
else:
if hasattr(self, 'gst'):
if self.gst:
# B x gst_dim
encoder_outputs, embedded_gst = self.compute_gst(encoder_outputs, style_mel)
encoder_outputs = torch.cat([encoder_outputs, embedded_gst], dim=-1)

View File

@ -83,6 +83,14 @@
"use_phonemes": false, // use phonemes instead of raw characters. It is suggested for better pronounciation.
"phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
"text_cleaner": "phoneme_cleaners",
"use_speaker_embedding": false // whether to use additional embeddings for separate speakers
"use_speaker_embedding": false, // whether to use additional embeddings for separate speakers
"use_gst": false,
"gst": {
"gst_style_input": null,
"gst_embedding_dim": 256,
"gst_num_heads": 4,
"gst_style_tokens": 10
}
}

View File

@ -51,14 +51,5 @@
"output_path": "result",
"min_seq_len": 0,
"max_seq_len": 300,
"log_dir": "tests/outputs/",
"use_speaker_embedding": false,
"use_gst": false,
"gst": {
"gst_style_input": null,
"gst_embedding_dim": 512,
"gst_num_heads": 4,
"gst_style_tokens": 10
}
"log_dir": "tests/outputs/"
}

View File

@ -22,7 +22,7 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
class TacotronTrainTest(unittest.TestCase):
def test_train_step(self):
input = torch.randint(0, 24, (8, 128)).long().to(device)
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 128, (8, )).long().to(device)
input_lengths = torch.sort(input_lengths, descending=True)[0]
mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
@ -35,7 +35,7 @@ class TacotronTrainTest(unittest.TestCase):
for idx in mel_lengths:
stop_targets[:, int(idx.item()):, 0] = 1.0
stop_targets = stop_targets.view(input.shape[0],
stop_targets = stop_targets.view(input_dummy.shape[0],
stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
@ -52,7 +52,63 @@ class TacotronTrainTest(unittest.TestCase):
optimizer = optim.Adam(model.parameters(), lr=c.lr)
for i in range(5):
mel_out, mel_postnet_out, align, stop_tokens = model.forward(
input, input_lengths, mel_spec, mel_lengths, speaker_ids)
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids)
assert torch.sigmoid(stop_tokens).data.max() <= 1.0
assert torch.sigmoid(stop_tokens).data.min() >= 0.0
optimizer.zero_grad()
loss = criterion(mel_out, mel_spec, mel_lengths)
stop_loss = criterion_st(stop_tokens, stop_targets)
loss = loss + criterion(mel_postnet_out, mel_postnet_spec, mel_lengths) + stop_loss
loss.backward()
optimizer.step()
# check parameter changes
count = 0
for param, param_ref in zip(model.parameters(),
model_ref.parameters()):
# ignore pre-higway layer since it works conditional
# if count not in [145, 59]:
assert (param != param_ref).any(
), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref)
count += 1
class TacotronGSTTrainTest(unittest.TestCase):
def test_train_step(self):
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
input_lengths = torch.randint(100, 128, (8, )).long().to(device)
input_lengths = torch.sort(input_lengths, descending=True)[0]
mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
mel_lengths[0] = 30
stop_targets = torch.zeros(8, 30, 1).float().to(device)
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
for idx in mel_lengths:
stop_targets[:, int(idx.item()):, 0] = 1.0
stop_targets = stop_targets.view(input_dummy.shape[0],
stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
criterion = MSELossMasked(seq_len_norm=False).to(device)
criterion_st = nn.BCEWithLogitsLoss().to(device)
model = Tacotron2(num_chars=24,
gst=True,
r=c.r,
num_speakers=5).to(device)
model.train()
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(),
model_ref.parameters()):
assert (param - param_ref).sum() == 0, param
count += 1
optimizer = optim.Adam(model.parameters(), lr=c.lr)
for i in range(5):
mel_out, mel_postnet_out, align, stop_tokens = model.forward(
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids)
assert torch.sigmoid(stop_tokens).data.max() <= 1.0
assert torch.sigmoid(stop_tokens).data.min() >= 0.0
optimizer.zero_grad()

View File

@ -31,7 +31,7 @@ class TacotronTrainTest(unittest.TestCase):
input_lengths = torch.randint(100, 129, (8, )).long().to(device)
input_lengths[-1] = 128
mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device)
linear_spec = torch.rand(8, 30, c.audio['fft_size']).to(device)
mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
stop_targets = torch.zeros(8, 30, 1).float().to(device)
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
@ -49,7 +49,7 @@ class TacotronTrainTest(unittest.TestCase):
model = Tacotron(
num_chars=32,
num_speakers=5,
postnet_output_dim=c.audio['num_freq'],
postnet_output_dim=c.audio['fft_size'],
decoder_output_dim=c.audio['num_mels'],
r=c.r,
memory_size=c.memory_size
@ -93,7 +93,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
input_lengths = torch.randint(100, 129, (8, )).long().to(device)
input_lengths[-1] = 128
mel_spec = torch.rand(8, 120, c.audio['num_mels']).to(device)
linear_spec = torch.rand(8, 120, c.audio['num_freq']).to(device)
linear_spec = torch.rand(8, 120, c.audio['fft_size']).to(device)
mel_lengths = torch.randint(20, 120, (8, )).long().to(device)
mel_lengths[-1] = 120
stop_targets = torch.zeros(8, 120, 1).float().to(device)
@ -113,7 +113,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
num_chars=32,
num_speakers=5,
gst=True,
postnet_output_dim=c.audio['num_freq'],
postnet_output_dim=c.audio['fft_size'],
decoder_output_dim=c.audio['num_mels'],
r=c.r,
memory_size=c.memory_size