mirror of https://github.com/coqui-ai/TTS.git
Better naming of variables
This commit is contained in:
parent 9f5e102473
commit 2fd37a5bad
@@ -15,7 +15,7 @@
     "lr": 0.003,
     "lr_patience": 5,
     "lr_decay": 0.5,
-    "batch_size": 128,
+    "batch_size": 98,
     "r": 5,

     "griffin_lim_iters": 60,
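The keys above come from the training configuration (learning-rate schedule, batch size, the reduction factor r, and Griffin-Lim iterations); this hunk only lowers the batch size from 128 to 98. As a rough sketch of how such a JSON file could be read into the attribute-style object that train.py later calls `c` (the AttrDict helper below is assumed, not part of this commit):

import json


class AttrDict(dict):
    """Expose config keys as attributes, e.g. c.lr, c.batch_size, c.r."""
    __getattr__ = dict.__getitem__


def load_config(path):
    # Hypothetical loader; the repository's own config handling is not shown in this diff.
    with open(path, "r") as f:
        return AttrDict(json.load(f))


# c = load_config("config.json")
# print(c.lr, c.batch_size, c.r, c.griffin_lim_iters)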
@@ -11,11 +11,11 @@ class BahdanauAttention(nn.Module):
         self.tanh = nn.Tanh()
         self.v = nn.Linear(dim, 1, bias=False)

-    def forward(self, query, processed_memory):
+    def forward(self, query, processed_inputs):
         """
         Args:
             query: (batch, 1, dim) or (batch, dim)
-            processed_memory: (batch, max_time, dim)
+            processed_inputs: (batch, max_time, dim)
         """
         if query.dim() == 2:
             # insert time-axis for broadcasting
@@ -24,7 +24,7 @@ class BahdanauAttention(nn.Module):
         processed_query = self.query_layer(query)

         # (batch, max_time, 1)
-        alignment = self.v(self.tanh(processed_query + processed_memory))
+        alignment = self.v(self.tanh(processed_query + processed_inputs))

         # (batch, max_time)
         return alignment.squeeze(-1)
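The renamed arguments do not change the math: this module computes standard additive (Bahdanau) attention scores, score_j = v^T tanh(W_q q + x_j), where x_j is a pre-projected encoder output. A self-contained sketch of the same scoring, with illustrative dimensions (not part of the commit):

import torch
import torch.nn as nn


class AdditiveAttention(nn.Module):
    """Sketch of the scoring above: score_j = v^T tanh(W_q query + processed_inputs_j)."""

    def __init__(self, dim):
        super().__init__()
        self.query_layer = nn.Linear(dim, dim, bias=False)
        self.tanh = nn.Tanh()
        self.v = nn.Linear(dim, 1, bias=False)

    def forward(self, query, processed_inputs):
        # (batch, dim) -> (batch, 1, dim) so the query broadcasts over time
        if query.dim() == 2:
            query = query.unsqueeze(1)
        processed_query = self.query_layer(query)
        # (batch, max_time, 1) -> (batch, max_time)
        return self.v(self.tanh(processed_query + processed_inputs)).squeeze(-1)


# scores = AdditiveAttention(256)(torch.zeros(4, 256), torch.zeros(4, 7, 256))  # -> (4, 7)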
@@ -44,44 +44,48 @@ def get_mask_from_lengths(memory, memory_lengths):


 class AttentionWrapper(nn.Module):
-    def __init__(self, rnn_cell, attention_mechanism,
+    def __init__(self, rnn_cell, alignment_model,
                  score_mask_value=-float("inf")):
         super(AttentionWrapper, self).__init__()
         self.rnn_cell = rnn_cell
-        self.attention_mechanism = attention_mechanism
+        self.alignment_model = alignment_model
         self.score_mask_value = score_mask_value

-    def forward(self, query, attention, cell_state, memory,
-                processed_memory=None, mask=None, memory_lengths=None):
-        if processed_memory is None:
-            processed_memory = memory
+    def forward(self, query, context_vec, cell_state, memory,
+                processed_inputs=None, mask=None, memory_lengths=None):
+
+        if processed_inputs is None:
+            processed_inputs = memory
+
         if memory_lengths is not None and mask is None:
             mask = get_mask_from_lengths(memory, memory_lengths)

-        # Concat input query and previous attention context
-        cell_input = torch.cat((query, attention), -1)
+        # Concat input query and previous context_vec context
+        import ipdb
+        ipdb.set_trace()
+        cell_input = torch.cat((query, context_vec), -1)

         # Feed it to RNN
         cell_output = self.rnn_cell(cell_input, cell_state)

         # Alignment
         # (batch, max_time)
-        alignment = self.attention_mechanism(cell_output, processed_memory)
+        alignment = self.alignment_model(cell_output, processed_inputs)

         if mask is not None:
             mask = mask.view(query.size(0), -1)
             alignment.data.masked_fill_(mask, self.score_mask_value)

-        # Normalize attention weight
+        # Normalize context_vec weight
         alignment = F.softmax(alignment, dim=-1)

         # Attention context vector
         # (batch, 1, dim)
-        attention = torch.bmm(alignment.unsqueeze(1), memory)
+        context_vec = torch.bmm(alignment.unsqueeze(1), memory)

         # (batch, dim)
-        attention = attention.squeeze(1)
+        context_vec = context_vec.squeeze(1)

-        return cell_output, attention, alignment
+        return cell_output, context_vec, alignment
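One forward step of the wrapper therefore concatenates the query with the previous context vector, advances the RNN cell, scores the cell output against the processed encoder outputs, masks padded time steps, normalises the scores with a softmax, and takes a weighted sum of `memory`. The tail of that computation can be restated as a small stand-alone helper (a sketch with assumed names, not the repository's API):

import torch
import torch.nn.functional as F


def attend(scores, memory, mask=None, score_mask_value=-float("inf")):
    """scores: (batch, max_time), memory: (batch, max_time, dim) -> context (batch, dim)."""
    if mask is not None:
        scores = scores.masked_fill(mask, score_mask_value)    # hide padded time steps
    weights = F.softmax(scores, dim=-1)                        # (batch, max_time)
    # (batch, 1, max_time) x (batch, max_time, dim) -> (batch, 1, dim)
    context_vec = torch.bmm(weights.unsqueeze(1), memory).squeeze(1)
    return context_vec, weights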
@@ -152,29 +152,27 @@ class Encoder(nn.Module):
 class Decoder(nn.Module):
     def __init__(self, memory_dim, r):
         super(Decoder, self).__init__()
-        self.max_decoder_steps = 200
         self.memory_dim = memory_dim
         self.r = r
+        # input -> |Linear| -> processed_inputs
+        self.input_layer = nn.Linear(256, 256, bias=False)
+        # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(memory_dim * r, sizes=[256, 128])
         # attetion RNN
+        # processed_inputs, prrocessed_memory -> |Attention| -> Attention, Alignment, RNN_State
         self.attention_rnn = AttentionWrapper(
             nn.GRUCell(256 + 128, 256),
             BahdanauAttention(256)
         )
-
-        self.memory_layer = nn.Linear(256, 256, bias=False)
-
-        # concat and project context and attention vectors
-        # (prenet_out + attention context) -> output
+        # (prenet_out | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(512, 256)
-
         # decoder RNNs
+        # decoder_RNN_input -> |RNN| -> RNN_state
         self.decoder_rnns = nn.ModuleList(
             [nn.GRUCell(256, 256) for _ in range(2)])
-
+        # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
+        self.max_decoder_steps = 200

-    def forward(self, decoder_inputs, memory=None, memory_lengths=None):
+    def forward(self, inputs, memory=None, memory_lengths=None):
         """
         Decoder forward step.

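The comments added in this hunk spell out the decoder's data flow. A quick, stand-alone dimension check of the linear projections declared above (the prenet and the attention wrapper are left out; memory_dim = 80 is an assumed mel-bin count, while r = 5 matches the config hunk earlier in this commit):

import torch
import torch.nn as nn

B, T_in, r = 4, 7, 5
memory_dim = 80               # assumed; not shown in this diff

input_layer = nn.Linear(256, 256, bias=False)
project_to_decoder_in = nn.Linear(512, 256)
proj_to_mel = nn.Linear(256, memory_dim * r)

processed_inputs = input_layer(torch.zeros(B, T_in, 256))     # (B, T_in, 256)
decoder_input = project_to_decoder_in(torch.zeros(B, 512))    # (B, 256): [attention hidden | context]
mel_frames = proj_to_mel(decoder_input)                       # (B, memory_dim * r): r frames per step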
@@ -182,17 +180,18 @@ class Decoder(nn.Module):
         Tacotron paper, greedy decoding is adapted.

         Args:
-            decoder_inputs: Encoder outputs. (B, T_encoder, dim)
+            inputs: Encoder outputs. (B, T_encoder, dim)
             memory: Decoder memory. i.e., mel-spectrogram. If None (at eval-time),
              decoder outputs are used as decoder inputs.
             memory_lengths: Encoder output (memory) lengths. If not None, used for
              attention masking.
         """
-        B = decoder_inputs.size(0)
+        B = inputs.size(0)

-        processed_memory = self.memory_layer(decoder_inputs)
+        # TODO: take thi segment into Attention module.
+        processed_inputs = self.input_layer(inputs)
         if memory_lengths is not None:
-            mask = get_mask_from_lengths(processed_memory, memory_lengths)
+            mask = get_mask_from_lengths(processed_inputs, memory_lengths)
         else:
             mask = None

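get_mask_from_lengths, whose body is not in this diff, is expected to return a mask that is truthy at padded time steps, since the attention wrapper fills those positions with score_mask_value before the softmax. A minimal equivalent under that assumption (the name length_mask is hypothetical):

import torch


def length_mask(lengths, max_len):
    """True at padded time steps; shape (batch, max_len). Assumed equivalent, not the repo's helper."""
    lengths = torch.as_tensor(lengths)
    steps = torch.arange(max_len).unsqueeze(0)      # (1, max_len)
    return steps >= lengths.unsqueeze(1)            # broadcast to (batch, max_len)


# length_mask([3, 5], 5)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])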
@@ -208,18 +207,18 @@ class Decoder(nn.Module):
                 self.memory_dim, self.r)
             T_decoder = memory.size(1)

-        # go frames - 0 frames tarting the sequence
-        initial_input = Variable(
-            decoder_inputs.data.new(B, self.memory_dim * self.r).zero_())
+        # go frame - 0 frames tarting the sequence
+        initial_memory = Variable(
+            inputs.data.new(B, self.memory_dim * self.r).zero_())

         # Init decoder states
         attention_rnn_hidden = Variable(
-            decoder_inputs.data.new(B, 256).zero_())
+            inputs.data.new(B, 256).zero_())
         decoder_rnn_hiddens = [Variable(
-            decoder_inputs.data.new(B, 256).zero_())
+            inputs.data.new(B, 256).zero_())
             for _ in range(len(self.decoder_rnns))]
-        current_attention = Variable(
-            decoder_inputs.data.new(B, 256).zero_())
+        current_context_vec = Variable(
+            inputs.data.new(B, 256).zero_())

         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
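The Variable(inputs.data.new(...).zero_()) idiom dates from pre-0.4 PyTorch; on current versions the same zero initialisation of the decoder state can be written with new_zeros. A sketch, not a change to the code above:

import torch


def init_decoder_states(inputs, memory_dim, r, n_decoder_rnns=2, hidden_size=256):
    """Zero states matching the ones built above, without Variable."""
    B = inputs.size(0)
    initial_memory = inputs.new_zeros(B, memory_dim * r)
    attention_rnn_hidden = inputs.new_zeros(B, hidden_size)
    decoder_rnn_hiddens = [inputs.new_zeros(B, hidden_size) for _ in range(n_decoder_rnns)]
    current_context_vec = inputs.new_zeros(B, hidden_size)
    return initial_memory, attention_rnn_hidden, decoder_rnn_hiddens, current_context_vec


# states = init_decoder_states(torch.zeros(4, 7, 256), memory_dim=80, r=5)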
@@ -229,21 +228,21 @@ class Decoder(nn.Module):
         alignments = []

         t = 0
-        current_input = initial_input
+        memory_input = initial_memory
         while True:
             if t > 0:
-                current_input = outputs[-1] if greedy else memory[t - 1]
+                memory_input = outputs[-1] if greedy else memory[t - 1]
             # Prenet
-            current_input = self.prenet(current_input)
+            memory_input = self.prenet(memory_input)

             # Attention RNN
-            attention_rnn_hidden, current_attention, alignment = self.attention_rnn(
-                current_input, current_attention, attention_rnn_hidden,
-                decoder_inputs, processed_memory=processed_memory, mask=mask)
+            attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
+                memory_input, current_context_vec, attention_rnn_hidden,
+                inputs, processed_inputs=processed_inputs, mask=mask)

             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
-                torch.cat((attention_rnn_hidden, current_attention), -1))
+                torch.cat((attention_rnn_hidden, current_context_vec), -1))

             # Pass through the decoder RNNs
             for idx in range(len(self.decoder_rnns)):
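The loop above switches between teacher forcing (feed the ground-truth frame group memory[t - 1]) and free running (feed back the model's own last output) depending on greedy, starting from the all-zero go frames. The selection logic, isolated as a sketch with assumed names:

def next_decoder_input(t, outputs, memory, initial_memory, greedy):
    """Pick the frame group fed to the prenet at step t, as in the loop above."""
    if t == 0:
        return initial_memory        # all-zero "go" frames
    if greedy:
        return outputs[-1]           # free running: feed back the last prediction
    return memory[t - 1]             # teacher forcing: feed the ground-truth frames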
@@ -266,7 +265,8 @@ class Decoder(nn.Module):
             if t > 1 and is_end_of_frames(output):
                 break
             elif t > self.max_decoder_steps:
-                print("Warning! doesn't seems to be converged")
+                print(" !! Decoder stopped with 'max_decoder_steps'. \
+                      Something is probably wrong.")
                 break
             else:
                 if t >= T_decoder:
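is_end_of_frames is not shown in this diff; a common heuristic for this kind of decoder is to stop once every value in the predicted frame group is close to silence, with the max_decoder_steps branch above acting as the safety net when that test never fires. One plausible shape for it (an assumption, not the committed implementation):

def is_end_of_frames(output, eps=0.2):
    """Assumed stopping test: the whole predicted frame group is near-silent."""
    return (output <= eps).all()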
train.py (13 lines changed)
@@ -124,7 +124,7 @@ def main(args):
         print("\n | > Epoch {}/{}".format(epoch, c.epochs))
         progbar = Progbar(len(dataset) / c.batch_size)

-        for i, data in enumerate(dataloader):
+        for num_iter, data in enumerate(dataloader):
            start_time = time.time()

            text_input = data[0]
@@ -132,8 +132,7 @@ def main(args):
            magnitude_input = data[2]
            mel_input = data[3]

-           current_step = i + args.restore_step + epoch * len(dataloader) + 1
-           print(current_step)
+           current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1

            # setup lr
            current_lr = lr_decay(c.lr, current_step)
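lr_decay(c.lr, current_step) is defined elsewhere in the repository and not shown here; a Noam-style warm-up/decay schedule is one plausible shape for it (an assumption, not the committed implementation):

import math


def lr_decay(init_lr, global_step, warmup_steps=4000.0):
    """Hypothetical schedule: linear warm-up, then decay proportional to 1/sqrt(step)."""
    step = global_step + 1.0
    return init_lr * math.sqrt(warmup_steps) * min(step * warmup_steps ** -1.5,
                                                   step ** -0.5)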
@@ -190,10 +189,10 @@ def main(args):
            step_time = time.time() - start_time
            epoch_time += step_time

-           progbar.update(i+1, values=[('total_loss', loss.data[0]),
-                                       ('linear_loss', linear_loss.data[0]),
-                                       ('mel_loss', mel_loss.data[0]),
-                                       ('grad_norm', grad_norm)])
+           progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
+                                              ('linear_loss', linear_loss.data[0]),
+                                              ('mel_loss', mel_loss.data[0]),
+                                              ('grad_norm', grad_norm)])

            # Plot Learning Stats
            tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step)