mirror of https://github.com/coqui-ai/TTS.git
RNN stop-token prediction
This commit is contained in:
parent
a31e60e928
commit
40f1a3d3a5
|
@ -232,12 +232,7 @@ class Decoder(nn.Module):
|
|||
[nn.GRUCell(256, 256) for _ in range(2)])
|
||||
# RNN_state -> |Linear| -> mel_spec
|
||||
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
||||
self.stopnet = nn.Sequential(nn.Linear(256 + self.r * memory_dim, memory_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(memory_dim, 256 + self.r * memory_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(256 + self.r * memory_dim, 1),
|
||||
nn.Sigmoid())
|
||||
self.stopnet = StopNet(r, memory_dim)
|
||||
|
||||
def forward(self, inputs, memory=None):
|
||||
"""
|
||||
|
@ -273,6 +268,7 @@ class Decoder(nn.Module):
|
|||
decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
|
||||
for _ in range(len(self.decoder_rnns))]
|
||||
current_context_vec = inputs.data.new(B, 256).zero_()
|
||||
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
|
||||
# Time first (T_decoder, B, memory_dim)
|
||||
if memory is not None:
|
||||
memory = memory.transpose(0, 1)
|
||||
|
@ -304,9 +300,9 @@ class Decoder(nn.Module):
|
|||
decoder_output = decoder_input
|
||||
# predict mel vectors from decoder vectors
|
||||
output = self.proj_to_mel(decoder_output)
|
||||
stop_input = torch.cat((output, decoder_output), -1)
|
||||
stop_input = output
|
||||
# predict stop token
|
||||
stop_token = self.stopnet(stop_input)
|
||||
stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden)
|
||||
outputs += [output]
|
||||
alignments += [alignment]
|
||||
stop_tokens += [stop_token]
|
||||
|
@ -327,3 +323,21 @@ class Decoder(nn.Module):
|
|||
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
||||
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
||||
return outputs, alignments, stop_tokens
|
||||
|
||||
|
||||
class StopNet(nn.Module):
|
||||
|
||||
def __init__(self, r, memory_dim):
|
||||
super(StopNet, self).__init__()
|
||||
self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
|
||||
self.relu = nn.ReLU()
|
||||
self.linear = nn.Linear(r * memory_dim, 1)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, inputs, rnn_hidden):
|
||||
rnn_hidden = self.rnn(inputs, rnn_hidden)
|
||||
outputs = self.relu(rnn_hidden)
|
||||
outputs = self.linear(outputs)
|
||||
outputs = self.sigmoid(outputs)
|
||||
return outputs, rnn_hidden
|
||||
|
5
train.py
5
train.py
|
@ -88,7 +88,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
|||
|
||||
# setup lr
|
||||
current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
|
||||
current_lr_st = lr_decay(0.01, current_step, c.warmup_steps)
|
||||
current_lr_st = lr_decay(c.lr, current_step, c.warmup_steps)
|
||||
|
||||
for params_group in optimizer.param_groups:
|
||||
params_group['lr'] = current_lr
|
||||
|
@ -363,8 +363,7 @@ def main(args):
|
|||
c.r)
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
optimizer_st = optim.SGD(model.decoder.stopnet.parameters(), lr=0.01,
|
||||
momentum=0.9, nesterov=True)
|
||||
optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
|
||||
|
||||
criterion = L1LossMasked()
|
||||
criterion_st = nn.BCELoss()
|
||||
|
|
Loading…
Reference in New Issue