plot attention alignments

Eren Golge 2018-02-02 05:37:09 -08:00
parent 235ce071c6
commit 4e0ab65bbf
5 changed files with 66 additions and 55 deletions

Binary file not shown.

@@ -53,7 +53,7 @@ class LJSpeechDataset(Dataset):
 text = [d['text'] for d in batch]
 text_lenghts = [len(x) for x in text]
-max_text_len = np.max(text_lengths)
+max_text_len = np.max(text_lenghts)
 wav = [d['wav'] for d in batch]
 # PAD sequences with largest length of the batch
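The fix above makes np.max read the list that was actually built (the misspelled text_lenghts). For reference, a minimal sketch of the padding step such a collate function performs, assuming integer-encoded text; pad_text_batch and its return values are illustrative names, not this repo's API:

import numpy as np

def pad_text_batch(batch):
    # batch: list of dicts, each with an integer-encoded 'text' sequence
    text = [d['text'] for d in batch]
    text_lenghts = [len(x) for x in text]   # same (misspelled) name the hunk now reads from
    max_text_len = np.max(text_lenghts)
    # zero-pad every sequence up to the longest one in the batch
    padded = np.zeros((len(text), max_text_len), dtype=np.int64)
    for i, seq in enumerate(text):
        padded[i, :len(seq)] = seq
    return padded, text_lenghts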

@@ -55,6 +55,7 @@ class AttentionWrapper(nn.Module):
 processed_memory=None, mask=None, memory_lengths=None):
 if processed_memory is None:
     processed_memory = memory
 if memory_lengths is not None and mask is None:
     mask = get_mask_from_lengths(memory, memory_lengths)
@@ -73,7 +74,7 @@ class AttentionWrapper(nn.Module):
 alignment.data.masked_fill_(mask, self.score_mask_value)
 # Normalize attention weight
-alignment = F.softmax(alignment, dim=-1)  ## TODO: might be buggy
+alignment = F.softmax(alignment, dim=-1)
 # Attention context vector
 # (batch, 1, dim)
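The second hunk masks padded encoder steps before normalizing the alignment. A minimal sketch of that masked-softmax step, assuming memory_lengths is a 1-D tensor of valid lengths; the function name and the way the mask is built here only loosely mirror get_mask_from_lengths:

import torch
import torch.nn.functional as F

def masked_attention_weights(alignment, memory_lengths, score_mask_value=-float("inf")):
    # alignment: (batch, max_time) raw attention scores
    # memory_lengths: (batch,) number of valid encoder steps per sample
    max_time = alignment.size(1)
    # True on padded positions, so they are pushed to -inf before the softmax
    mask = torch.arange(max_time, device=alignment.device)[None, :] >= memory_lengths[:, None]
    alignment = alignment.masked_fill(mask, score_mask_value)
    # normalize over encoder time; padded steps end up with zero weight
    return F.softmax(alignment, dim=-1)

Masking before the softmax (rather than zeroing afterwards) keeps the weights over the valid steps summing to one.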

File diff suppressed because one or more lines are too long

@@ -123,7 +123,7 @@ def main(args):
 # setup lr
 current_lr = lr_decay(c.lr, current_step)
 for params_group in optimizer.param_groups:
-    param_group['lr'] = current_lr
+    params_group['lr'] = current_lr
 optimizer.zero_grad()
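For context, a self-contained sketch of the per-step learning-rate update this hunk repairs; the Noam-style lr_decay below is an assumption, the repo's own lr_decay may use a different schedule:

import torch

def lr_decay(init_lr, global_step, warmup_steps=4000):
    # Noam-style warmup then inverse-sqrt decay (an assumption, not the repo's exact schedule)
    step = global_step + 1
    return init_lr * warmup_steps ** 0.5 * min(step * warmup_steps ** -1.5, step ** -0.5)

model = torch.nn.Linear(10, 10)                        # stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=0.002)

for current_step in range(1, 101):
    current_lr = lr_decay(0.002, current_step)
    for params_group in optimizer.param_groups:        # same loop variable the fix now uses
        params_group['lr'] = current_lr
    optimizer.zero_grad()
    # ... forward, backward and optimizer.step() would follow ...
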
@@ -204,11 +204,14 @@ def main(args):
 checkpoint_path)
 print("\n | > Checkpoint is saved : {}".format(checkpoint_path))
-# Log spectrogram reconstruction
+# Diagnostic visualizations
 const_spec = linear_output[0].data.cpu()[None, :]
 gt_spec = linear_spec_var[0].data.cpu()[None, :]
+align_img = alignments[0].data.cpu().t()[None, :]
 tb.add_image('Spec/Reconstruction', const_spec, current_step)
 tb.add_image('Spec/GroundTruth', gt_spec, current_step)
+tb.add_image('Attn/Alignment', align_img, current_step)
 #lr_scheduler.step(loss.data[0])
 tb.add_scalar('Time/EpochTime', epoch_time, epoch)
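The headline change logs the raw (1, encoder_steps, decoder_steps) alignment tensor with add_image. An alternative many Tacotron setups use is to render the alignment as a labeled heatmap and log the figure; the sketch below assumes matplotlib and the torch.utils.tensorboard SummaryWriter rather than whatever logger tb is in this repo:

import matplotlib
matplotlib.use("Agg")                      # headless backend for training machines
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.tensorboard import SummaryWriter

def plot_alignment(alignment):
    # alignment: (decoder_steps, encoder_steps) attention weights for one sample
    fig, ax = plt.subplots()
    im = ax.imshow(alignment.T, aspect="auto", origin="lower", interpolation="none")
    fig.colorbar(im, ax=ax)
    ax.set_xlabel("Decoder step")
    ax.set_ylabel("Encoder step")
    return fig

writer = SummaryWriter("logs")
fake_alignment = np.random.rand(80, 40)    # placeholder; a real run would pass alignments[0]
writer.add_figure("Attn/Alignment", plot_alignment(fake_alignment), global_step=0)
writer.close()

A clean diagonal band in this plot is the usual sign that the attention has learned a monotonic text-to-audio alignment.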