mirror of https://github.com/coqui-ai/TTS.git
RNN stop-token prediction
This commit is contained in:
parent
7d5fe41ab4
commit
7a64c6e383
|
@ -232,12 +232,7 @@ class Decoder(nn.Module):
|
||||||
[nn.GRUCell(256, 256) for _ in range(2)])
|
[nn.GRUCell(256, 256) for _ in range(2)])
|
||||||
# RNN_state -> |Linear| -> mel_spec
|
# RNN_state -> |Linear| -> mel_spec
|
||||||
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
||||||
self.stopnet = nn.Sequential(nn.Linear(256 + self.r * memory_dim, memory_dim),
|
self.stopnet = StopNet(r, memory_dim)
|
||||||
nn.ReLU(),
|
|
||||||
nn.Linear(memory_dim, 256 + self.r * memory_dim),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Linear(256 + self.r * memory_dim, 1),
|
|
||||||
nn.Sigmoid())
|
|
||||||
|
|
||||||
def forward(self, inputs, memory=None):
|
def forward(self, inputs, memory=None):
|
||||||
"""
|
"""
|
||||||
|
@ -273,6 +268,7 @@ class Decoder(nn.Module):
|
||||||
decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
|
decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
|
||||||
for _ in range(len(self.decoder_rnns))]
|
for _ in range(len(self.decoder_rnns))]
|
||||||
current_context_vec = inputs.data.new(B, 256).zero_()
|
current_context_vec = inputs.data.new(B, 256).zero_()
|
||||||
|
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
|
||||||
# Time first (T_decoder, B, memory_dim)
|
# Time first (T_decoder, B, memory_dim)
|
||||||
if memory is not None:
|
if memory is not None:
|
||||||
memory = memory.transpose(0, 1)
|
memory = memory.transpose(0, 1)
|
||||||
|
@ -304,9 +300,9 @@ class Decoder(nn.Module):
|
||||||
decoder_output = decoder_input
|
decoder_output = decoder_input
|
||||||
# predict mel vectors from decoder vectors
|
# predict mel vectors from decoder vectors
|
||||||
output = self.proj_to_mel(decoder_output)
|
output = self.proj_to_mel(decoder_output)
|
||||||
stop_input = torch.cat((output, decoder_output), -1)
|
stop_input = output
|
||||||
# predict stop token
|
# predict stop token
|
||||||
stop_token = self.stopnet(stop_input)
|
stop_token, stopnet_rnn_hidden = self.stopnet(stop_input, stopnet_rnn_hidden)
|
||||||
outputs += [output]
|
outputs += [output]
|
||||||
alignments += [alignment]
|
alignments += [alignment]
|
||||||
stop_tokens += [stop_token]
|
stop_tokens += [stop_token]
|
||||||
|
@ -327,3 +323,21 @@ class Decoder(nn.Module):
|
||||||
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
||||||
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
||||||
return outputs, alignments, stop_tokens
|
return outputs, alignments, stop_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class StopNet(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, r, memory_dim):
|
||||||
|
super(StopNet, self).__init__()
|
||||||
|
self.rnn = nn.GRUCell(memory_dim * r, memory_dim * r)
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.linear = nn.Linear(r * memory_dim, 1)
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, inputs, rnn_hidden):
|
||||||
|
rnn_hidden = self.rnn(inputs, rnn_hidden)
|
||||||
|
outputs = self.relu(rnn_hidden)
|
||||||
|
outputs = self.linear(outputs)
|
||||||
|
outputs = self.sigmoid(outputs)
|
||||||
|
return outputs, rnn_hidden
|
||||||
|
|
5
train.py
5
train.py
|
@ -88,7 +88,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
||||||
|
|
||||||
# setup lr
|
# setup lr
|
||||||
current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
|
current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
|
||||||
current_lr_st = lr_decay(0.01, current_step, c.warmup_steps)
|
current_lr_st = lr_decay(c.lr, current_step, c.warmup_steps)
|
||||||
|
|
||||||
for params_group in optimizer.param_groups:
|
for params_group in optimizer.param_groups:
|
||||||
params_group['lr'] = current_lr
|
params_group['lr'] = current_lr
|
||||||
|
@ -363,8 +363,7 @@ def main(args):
|
||||||
c.r)
|
c.r)
|
||||||
|
|
||||||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||||
optimizer_st = optim.SGD(model.decoder.stopnet.parameters(), lr=0.01,
|
optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
|
||||||
momentum=0.9, nesterov=True)
|
|
||||||
|
|
||||||
criterion = L1LossMasked()
|
criterion = L1LossMasked()
|
||||||
criterion_st = nn.BCELoss()
|
criterion_st = nn.BCELoss()
|
||||||
|
|
Loading…
Reference in New Issue