remove attention mask

Eren Golge 2018-03-19 08:26:16 -07:00
parent 1b9f07918e
commit 3071e7f6f6
4 changed files with 12 additions and 26 deletions

@@ -12,20 +12,20 @@
     "text_cleaner": "english_cleaners",
     "epochs": 2000,
-    "lr": 0.00001875,
+    "lr": 0.001,
     "warmup_steps": 4000,
-    "batch_size": 2,
+    "batch_size": 32,
     "eval_batch_size": 32,
     "r": 5,
     "griffin_lim_iters": 60,
     "power": 1.5,
-    "num_loader_workers": 16,
+    "num_loader_workers": 12,
     "checkpoint": false,
     "save_step": 69,
     "data_path": "/run/shm/erogol/LJSpeech-1.0",
-    "min_seq_len": 90,
+    "min_seq_len": 0,
     "output_path": "result"
 }

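For context, a JSON config like the one above is typically parsed once at startup and read with attribute access, which is how the training script further down refers to c.lr, c.r and c.num_mels. A minimal loader sketch, with no claim about the repo's own helper (AttrDict and load_config are hypothetical names):

import json

class AttrDict(dict):
    # dict subclass that allows attribute access, so cfg["lr"] can be read as cfg.lr
    __getattr__ = dict.__getitem__

def load_config(path):
    # parse the JSON config file into an attribute-accessible dict
    with open(path) as f:
        return AttrDict(json.load(f))

c = load_config("config.json")
print(c.lr, c.batch_size, c.num_loader_workers)  # 0.001 32 12 with the values above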
@@ -231,8 +231,8 @@ class Decoder(nn.Module):
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)

-    def forward(self, inputs, memory=None, input_lengths=None):
-        r"""
+    def forward(self, inputs, memory=None):
+        """
         Decoder forward step.

         If decoder inputs are not given (e.g., at testing time), as noted in
@@ -242,8 +242,6 @@ class Decoder(nn.Module):
             inputs: Encoder outputs.
             memory (None): Decoder memory (autoregression. If None (at eval-time),
                 decoder outputs are used as decoder inputs.
-            input_lengths (None): input lengths, used for
-                attention masking.

         Shapes:
             - inputs: batch x time x encoder_out_dim
@@ -251,12 +249,6 @@ class Decoder(nn.Module):
         """
         B = inputs.size(0)

-        # if input_lengths is not None:
-        #     mask = get_mask_from_lengths(processed_inputs, input_lengths)
-        # else:
-        #     mask = None

         # Run greedy decoding if memory is None
         greedy = memory is None

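The commented-out block removed above referenced a get_mask_from_lengths helper. For orientation, a typical PyTorch version of that idea is sketched below; this is an assumption about what such a helper does, not the code from this repo:

import torch

def get_mask_from_lengths(inputs, input_lengths):
    # boolean mask of shape (B, T) that is True on padded timesteps;
    # T is the padded length of `inputs`, and `input_lengths` holds the
    # number of valid steps per sample
    max_len = inputs.size(1)
    steps = torch.arange(max_len, device=input_lengths.device)
    return steps.unsqueeze(0) >= input_lengths.unsqueeze(1)

Such a mask is normally applied to the attention scores before the softmax, e.g. scores.masked_fill_(mask, -float("inf")), so that padded encoder steps receive zero attention weight; this commit drops that masking path from the decoder.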
@@ -8,12 +8,11 @@ from TTS.layers.tacotron import Prenet, Encoder, Decoder, CBHG
 class Tacotron(nn.Module):
     def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
-                 freq_dim=1025, r=5, padding_idx=None,
-                 use_atten_mask=False):
+                 freq_dim=1025, r=5, padding_idx=None):
         super(Tacotron, self).__init__()
         self.mel_dim = mel_dim
         self.linear_dim = linear_dim
-        self.use_atten_mask = use_atten_mask
         self.embedding = nn.Embedding(len(symbols), embedding_dim,
                                       padding_idx=padding_idx)
         print(" | > Embedding dim : {}".format(len(symbols)))
@@ -26,16 +25,13 @@ class Tacotron(nn.Module):
         self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
         self.last_linear = nn.Linear(mel_dim * 2, freq_dim)

-    def forward(self, characters, mel_specs=None, input_lengths=None):
+    def forward(self, characters, mel_specs=None):
         B = characters.size(0)
         inputs = self.embedding(characters)
         # (B, T', in_dim)
         encoder_outputs = self.encoder(inputs)
-        if not self.use_atten_mask:
-            input_lengths = None
         # (B, T', mel_dim*r)
         mel_outputs, alignments = self.decoder(
-            encoder_outputs, mel_specs, input_lengths=input_lengths)
+            encoder_outputs, mel_specs)

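With the mask gone, the model's forward pass takes only the character ids and, for teacher forcing, the target mel spectrogram, and still returns mel outputs, linear outputs and alignments (see the train script below). A minimal smoke-test sketch of the simplified interface; the import path and tensor shapes are assumptions, not verified against the repo:

import torch
from TTS.models.tacotron import Tacotron  # module path assumed

model = Tacotron(embedding_dim=256, linear_dim=1025, mel_dim=80,
                 freq_dim=1025, r=5)
chars = torch.randint(1, 10, (2, 50))  # (B, T_in) character ids
mels = torch.randn(2, 120, 80)         # (B, T_out, mel_dim), T_out a multiple of r
mel_out, linear_out, alignments = model(chars, mels)  # no input_lengths argument anymore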
@@ -112,8 +112,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
         # forward pass
         mel_output, linear_output, alignments =\
-            model.forward(text_input_var, mel_spec_var,
-                          input_lengths= torch.autograd.Variable(torch.cuda.LongTensor(sorted_lengths)))
+            model.forward(text_input_var, mel_spec_var)

         # loss computation
         mel_loss = criterion(mel_output, mel_spec_var)
@@ -337,8 +336,7 @@ def main(args):
                      c.hidden_size,
                      c.num_mels,
                      c.num_freq,
-                     c.r,
-                     use_atten_mask=True)
+                     c.r)

     optimizer = optim.Adam(model.parameters(), lr=c.lr)
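Continuing the sketch above (reusing model, chars and mels), one optimization step after this commit reduces to a plain forward/backward cycle. Variable names follow the diff; the loss type and the linear-spectrogram target are assumptions:

import torch
import torch.nn as nn
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=0.001)  # lr taken from the new config
criterion = nn.L1Loss()                               # loss type assumed
linear_target = torch.randn(2, 120, 1025)             # (B, T_out, freq_dim), assumed shape

optimizer.zero_grad()
mel_out, linear_out, alignments = model(chars, mels)  # forward pass without input_lengths
loss = criterion(mel_out, mels) + criterion(linear_out, linear_target)
loss.backward()
optimizer.step()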