mirror of https://github.com/coqui-ai/TTS.git
Better naming of variables
commit 2fd37a5bad
parent 9f5e102473
@@ -15,7 +15,7 @@
     "lr": 0.003,
     "lr_patience": 5,
     "lr_decay": 0.5,
-    "batch_size": 128,
+    "batch_size": 98,
     "r": 5,
 
     "griffin_lim_iters": 60,
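A note on "r" above: it is the Tacotron reduction factor, the number of spectrogram frames the decoder emits per step, which is why the Decoder below sizes its prenet input and output projection by memory_dim * r. A minimal sketch of the frame grouping this implies (shapes and names here are illustrative, not from this repo):

    import torch

    B, T, memory_dim, r = 2, 100, 80, 5    # r = 5 as in the config
    mel = torch.randn(B, T, memory_dim)    # full mel-spectrogram target
    # group r consecutive frames so one decoder step predicts all of them,
    # matching proj_to_mel = nn.Linear(256, memory_dim * r) further down
    grouped = mel.view(B, T // r, memory_dim * r)  # (B, T // r, memory_dim * r)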
@@ -11,11 +11,11 @@ class BahdanauAttention(nn.Module):
         self.tanh = nn.Tanh()
         self.v = nn.Linear(dim, 1, bias=False)
 
-    def forward(self, query, processed_memory):
+    def forward(self, query, processed_inputs):
         """
         Args:
             query: (batch, 1, dim) or (batch, dim)
-            processed_memory: (batch, max_time, dim)
+            processed_inputs: (batch, max_time, dim)
         """
         if query.dim() == 2:
             # insert time-axis for broadcasting
@@ -24,7 +24,7 @@ class BahdanauAttention(nn.Module):
         processed_query = self.query_layer(query)
 
         # (batch, max_time, 1)
-        alignment = self.v(self.tanh(processed_query + processed_memory))
+        alignment = self.v(self.tanh(processed_query + processed_inputs))
 
         # (batch, max_time)
         return alignment.squeeze(-1)
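The score in this hunk is the additive (Bahdanau) form: project the query, add it to the pre-projected encoder states, squash with tanh, and reduce to one scalar per timestep with v. A self-contained restatement, assuming the query projection is nn.Linear(dim, dim, bias=False) (only tanh and v are visible in the hunks above, so that layer is an assumption):

    import torch
    import torch.nn as nn

    class AdditiveScore(nn.Module):
        """Minimal sketch of the BahdanauAttention forward shown above."""
        def __init__(self, dim):
            super(AdditiveScore, self).__init__()
            self.query_layer = nn.Linear(dim, dim, bias=False)  # assumed, not in the diff
            self.tanh = nn.Tanh()
            self.v = nn.Linear(dim, 1, bias=False)

        def forward(self, query, processed_inputs):
            if query.dim() == 2:
                query = query.unsqueeze(1)  # (B, 1, dim) for broadcasting over time
            processed_query = self.query_layer(query)
            # (B, max_time, 1) -> (B, max_time)
            return self.v(self.tanh(processed_query + processed_inputs)).squeeze(-1)

    scores = AdditiveScore(256)(torch.randn(4, 256), torch.randn(4, 17, 256))
    # scores.shape == torch.Size([4, 17]): one score per encoder timestep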
@@ -44,44 +44,48 @@ def get_mask_from_lengths(memory, memory_lengths):
 
 
 class AttentionWrapper(nn.Module):
-    def __init__(self, rnn_cell, attention_mechanism,
+    def __init__(self, rnn_cell, alignment_model,
                  score_mask_value=-float("inf")):
         super(AttentionWrapper, self).__init__()
         self.rnn_cell = rnn_cell
-        self.attention_mechanism = attention_mechanism
+        self.alignment_model = alignment_model
         self.score_mask_value = score_mask_value
 
-    def forward(self, query, attention, cell_state, memory,
-                processed_memory=None, mask=None, memory_lengths=None):
-        if processed_memory is None:
-            processed_memory = memory
+    def forward(self, query, context_vec, cell_state, memory,
+                processed_inputs=None, mask=None, memory_lengths=None):
+        if processed_inputs is None:
+            processed_inputs = memory
 
         if memory_lengths is not None and mask is None:
             mask = get_mask_from_lengths(memory, memory_lengths)
 
-        # Concat input query and previous attention context
-        cell_input = torch.cat((query, attention), -1)
+        # Concat input query and previous context vector
+        import ipdb
+        ipdb.set_trace()
+        cell_input = torch.cat((query, context_vec), -1)
 
         # Feed it to RNN
         cell_output = self.rnn_cell(cell_input, cell_state)
 
         # Alignment
         # (batch, max_time)
-        alignment = self.attention_mechanism(cell_output, processed_memory)
+        alignment = self.alignment_model(cell_output, processed_inputs)
 
         if mask is not None:
             mask = mask.view(query.size(0), -1)
             alignment.data.masked_fill_(mask, self.score_mask_value)
 
-        # Normalize attention weight
+        # Normalize alignment weights
         alignment = F.softmax(alignment, dim=-1)
 
         # Attention context vector
         # (batch, 1, dim)
-        attention = torch.bmm(alignment.unsqueeze(1), memory)
+        context_vec = torch.bmm(alignment.unsqueeze(1), memory)
 
         # (batch, dim)
-        attention = attention.squeeze(1)
+        context_vec = context_vec.squeeze(1)
 
-        return cell_output, attention, alignment
+        return cell_output, context_vec, alignment
 
 
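The mask/softmax/bmm sequence in AttentionWrapper.forward is standard masked attention pooling: invalid timesteps are pushed to -inf so softmax assigns them zero weight, then the weights pool the encoder outputs into a context vector. The same steps in isolation (shapes assumed):

    import torch
    import torch.nn.functional as F

    B, max_time, dim = 2, 6, 256
    alignment = torch.randn(B, max_time)    # raw scores from the alignment model
    memory = torch.randn(B, max_time, dim)  # encoder outputs
    lengths = torch.tensor([6, 4])          # valid timesteps per batch item

    # positions at or past each sequence length get -inf, so softmax zeroes them
    mask = torch.arange(max_time).unsqueeze(0) >= lengths.unsqueeze(1)
    alignment = alignment.masked_fill(mask, -float("inf"))
    weights = F.softmax(alignment, dim=-1)  # (B, max_time), each row sums to 1

    # weighted sum of encoder outputs -> attention context vector, (B, dim)
    context_vec = torch.bmm(weights.unsqueeze(1), memory).squeeze(1)

Note that the hunk above also adds an import ipdb / ipdb.set_trace() pair to the new code; that is a debugger breakpoint and will halt every forward pass.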
@@ -152,29 +152,27 @@ class Encoder(nn.Module):
 class Decoder(nn.Module):
     def __init__(self, memory_dim, r):
         super(Decoder, self).__init__()
+        self.max_decoder_steps = 200
         self.memory_dim = memory_dim
         self.r = r
+        # input -> |Linear| -> processed_inputs
+        self.input_layer = nn.Linear(256, 256, bias=False)
+        # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(memory_dim * r, sizes=[256, 128])
-        # attetion RNN
+        # processed_inputs, processed_memory -> |Attention| -> attention, alignment, RNN state
         self.attention_rnn = AttentionWrapper(
             nn.GRUCell(256 + 128, 256),
             BahdanauAttention(256)
         )
-        self.memory_layer = nn.Linear(256, 256, bias=False)
-
-        # concat and project context and attention vectors
-        # (prenet_out + attention context) -> output
+        # (prenet_out | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(512, 256)
-        # decoder RNNs
+        # decoder_RNN_input -> |RNN| -> RNN_state
         self.decoder_rnns = nn.ModuleList(
             [nn.GRUCell(256, 256) for _ in range(2)])
+        # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
-        self.max_decoder_steps = 200
 
-    def forward(self, decoder_inputs, memory=None, memory_lengths=None):
+    def forward(self, inputs, memory=None, memory_lengths=None):
         """
         Decoder forward step.
 
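The new |Linear| / |Prenet| / |Attention| comments in __init__ spell out the per-step dataflow. Traced with concrete layers at the constructor's sizes (a sketch only: Prenet is stood in for by two ReLU linears, and the attention computation that also updates the context vector each step is omitted):

    import torch
    import torch.nn as nn

    B, memory_dim, r = 2, 80, 5
    prenet = nn.Sequential(nn.Linear(memory_dim * r, 256), nn.ReLU(),
                           nn.Linear(256, 128), nn.ReLU())
    attention_rnn = nn.GRUCell(256 + 128, 256)
    project_to_decoder_in = nn.Linear(512, 256)
    decoder_rnns = nn.ModuleList([nn.GRUCell(256, 256) for _ in range(2)])
    proj_to_mel = nn.Linear(256, memory_dim * r)

    memory_input = torch.randn(B, memory_dim * r)  # last r frames (or the go frame)
    context_vec = torch.randn(B, 256)              # previous attention context
    attention_rnn_hidden = torch.randn(B, 256)

    x = prenet(memory_input)                                    # (B, 128)
    attention_rnn_hidden = attention_rnn(
        torch.cat((x, context_vec), -1), attention_rnn_hidden)  # (B, 256)
    decoder_input = project_to_decoder_in(
        torch.cat((attention_rnn_hidden, context_vec), -1))     # (B, 256)
    h = decoder_input
    for cell in decoder_rnns:
        h = cell(h)                                             # (B, 256)
    output = proj_to_mel(h)                                     # r frames at once: (B, 400)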
@@ -182,17 +180,18 @@ class Decoder(nn.Module):
         Tacotron paper, greedy decoding is adapted.
 
         Args:
-            decoder_inputs: Encoder outputs. (B, T_encoder, dim)
+            inputs: Encoder outputs. (B, T_encoder, dim)
             memory: Decoder memory. i.e., mel-spectrogram. If None (at eval-time),
               decoder outputs are used as decoder inputs.
             memory_lengths: Encoder output (memory) lengths. If not None, used for
               attention masking.
         """
-        B = decoder_inputs.size(0)
+        B = inputs.size(0)
 
-        processed_memory = self.memory_layer(decoder_inputs)
+        # TODO: take this segment into the Attention module.
+        processed_inputs = self.input_layer(inputs)
         if memory_lengths is not None:
-            mask = get_mask_from_lengths(processed_memory, memory_lengths)
+            mask = get_mask_from_lengths(processed_inputs, memory_lengths)
         else:
             mask = None
 
@@ -208,18 +207,18 @@ class Decoder(nn.Module):
                 self.memory_dim, self.r)
             T_decoder = memory.size(1)
 
-        # go frames - 0 frames tarting the sequence
-        initial_input = Variable(
-            decoder_inputs.data.new(B, self.memory_dim * self.r).zero_())
+        # go frame - zero frame starting the sequence
+        initial_memory = Variable(
+            inputs.data.new(B, self.memory_dim * self.r).zero_())
 
         # Init decoder states
         attention_rnn_hidden = Variable(
-            decoder_inputs.data.new(B, 256).zero_())
+            inputs.data.new(B, 256).zero_())
         decoder_rnn_hiddens = [Variable(
-            decoder_inputs.data.new(B, 256).zero_())
+            inputs.data.new(B, 256).zero_())
             for _ in range(len(self.decoder_rnns))]
-        current_attention = Variable(
-            decoder_inputs.data.new(B, 256).zero_())
+        current_context_vec = Variable(
+            inputs.data.new(B, 256).zero_())
 
         # Time first (T_decoder, B, memory_dim)
         if memory is not None:
@@ -229,21 +228,21 @@ class Decoder(nn.Module):
         alignments = []
 
         t = 0
-        current_input = initial_input
+        memory_input = initial_memory
         while True:
             if t > 0:
-                current_input = outputs[-1] if greedy else memory[t - 1]
+                memory_input = outputs[-1] if greedy else memory[t - 1]
             # Prenet
-            current_input = self.prenet(current_input)
+            memory_input = self.prenet(memory_input)
 
             # Attention RNN
-            attention_rnn_hidden, current_attention, alignment = self.attention_rnn(
-                current_input, current_attention, attention_rnn_hidden,
-                decoder_inputs, processed_memory=processed_memory, mask=mask)
+            attention_rnn_hidden, current_context_vec, alignment = self.attention_rnn(
+                memory_input, current_context_vec, attention_rnn_hidden,
+                inputs, processed_inputs=processed_inputs, mask=mask)
 
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
-                torch.cat((attention_rnn_hidden, current_attention), -1))
+                torch.cat((attention_rnn_hidden, current_context_vec), -1))
 
             # Pass through the decoder RNNs
             for idx in range(len(self.decoder_rnns)):
@@ -266,7 +265,8 @@ class Decoder(nn.Module):
                 if t > 1 and is_end_of_frames(output):
                     break
                 elif t > self.max_decoder_steps:
-                    print("Warning! doesn't seems to be converged")
+                    print(" !! Decoder stopped with 'max_decoder_steps'. \
+                          Something is probably wrong.")
                     break
             else:
                 if t >= T_decoder:
--- a/train.py
+++ b/train.py
@@ -124,7 +124,7 @@ def main(args):
         print("\n | > Epoch {}/{}".format(epoch, c.epochs))
         progbar = Progbar(len(dataset) / c.batch_size)
 
-        for i, data in enumerate(dataloader):
+        for num_iter, data in enumerate(dataloader):
             start_time = time.time()
 
             text_input = data[0]
@@ -132,8 +132,7 @@ def main(args):
             magnitude_input = data[2]
             mel_input = data[3]
 
-            current_step = i + args.restore_step + epoch * len(dataloader) + 1
-            print(current_step)
+            current_step = num_iter + args.restore_step + epoch * len(dataloader) + 1
 
             # setup lr
             current_lr = lr_decay(c.lr, current_step)
|
||||||
step_time = time.time() - start_time
|
step_time = time.time() - start_time
|
||||||
epoch_time += step_time
|
epoch_time += step_time
|
||||||
|
|
||||||
progbar.update(i+1, values=[('total_loss', loss.data[0]),
|
progbar.update(num_iter+1, values=[('total_loss', loss.data[0]),
|
||||||
('linear_loss', linear_loss.data[0]),
|
('linear_loss', linear_loss.data[0]),
|
||||||
('mel_loss', mel_loss.data[0]),
|
('mel_loss', mel_loss.data[0]),
|
||||||
('grad_norm', grad_norm)])
|
('grad_norm', grad_norm)])
|
||||||
|
|
||||||
# Plot Learning Stats
|
# Plot Learning Stats
|
||||||
tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step)
|
tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step)
|
||||||
|
|
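lr_decay is called as lr_decay(c.lr, current_step) but defined outside this diff. Given "lr": 0.003 and "lr_decay": 0.5 in the config, one plausible reading is a step-wise decay; the sketch below is a guess at the interface, not the repo's implementation, and decay_steps is invented here:

    def lr_decay(init_lr, global_step, decay_rate=0.5, decay_steps=50000):
        # hypothetical schedule: multiply by decay_rate every decay_steps steps
        return init_lr * (decay_rate ** (global_step // decay_steps))

    current_lr = lr_decay(0.003, 120000)  # 0.003 * 0.5 ** 2 = 0.00075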