mirror of https://github.com/coqui-ai/TTS.git
config
This commit is contained in:
parent
e9bf49e1c3
commit
030e942396
|
@ -11,7 +11,7 @@
|
||||||
"embedding_size": 256,
|
"embedding_size": 256,
|
||||||
"text_cleaner": "english_cleaners",
|
"text_cleaner": "english_cleaners",
|
||||||
|
|
||||||
"epochs": 50,
|
"epochs": 500,
|
||||||
"lr": 0.002,
|
"lr": 0.002,
|
||||||
"warmup_steps": 4000,
|
"warmup_steps": 4000,
|
||||||
"batch_size": 32,
|
"batch_size": 32,
|
||||||
|
|
|
@ -233,6 +233,8 @@ class Decoder(nn.Module):
|
||||||
[nn.GRUCell(256, 256) for _ in range(2)])
|
[nn.GRUCell(256, 256) for _ in range(2)])
|
||||||
# RNN_state -> |Linear| -> mel_spec
|
# RNN_state -> |Linear| -> mel_spec
|
||||||
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
self.proj_to_mel = nn.Linear(256, memory_dim * r)
|
||||||
|
self.stopnet = nn.Sequential(nn.Dropout(0.3), nn.Linear(80 * self.r, 1), nn.Sigmoid())
|
||||||
|
|
||||||
|
|
||||||
def forward(self, inputs, memory=None):
|
def forward(self, inputs, memory=None):
|
||||||
"""
|
"""
|
||||||
|
@ -277,6 +279,7 @@ class Decoder(nn.Module):
|
||||||
memory = memory.transpose(0, 1)
|
memory = memory.transpose(0, 1)
|
||||||
outputs = []
|
outputs = []
|
||||||
alignments = []
|
alignments = []
|
||||||
|
stop_tokens = []
|
||||||
t = 0
|
t = 0
|
||||||
memory_input = initial_memory
|
memory_input = initial_memory
|
||||||
while True:
|
while True:
|
||||||
|
@ -302,8 +305,11 @@ class Decoder(nn.Module):
|
||||||
output = decoder_input
|
output = decoder_input
|
||||||
# predict mel vectors from decoder vectors
|
# predict mel vectors from decoder vectors
|
||||||
output = self.proj_to_mel(output)
|
output = self.proj_to_mel(output)
|
||||||
|
# predict stop token
|
||||||
|
stop_token = self.stopnet(output)
|
||||||
outputs += [output]
|
outputs += [output]
|
||||||
alignments += [alignment]
|
alignments += [alignment]
|
||||||
|
stop_tokens += [stop_token]
|
||||||
t += 1
|
t += 1
|
||||||
if (not greedy and self.training) or (greedy and memory is not None):
|
if (not greedy and self.training) or (greedy and memory is not None):
|
||||||
if t >= T_decoder:
|
if t >= T_decoder:
|
||||||
|
@ -319,7 +325,8 @@ class Decoder(nn.Module):
|
||||||
# Back to batch first
|
# Back to batch first
|
||||||
alignments = torch.stack(alignments).transpose(0, 1)
|
alignments = torch.stack(alignments).transpose(0, 1)
|
||||||
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
|
||||||
return outputs, alignments
|
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
|
||||||
|
return outputs, alignments, stop_tokens
|
||||||
|
|
||||||
|
|
||||||
def is_end_of_frames(output, alignment, eps=0.01): # 0.2
|
def is_end_of_frames(output, alignment, eps=0.01): # 0.2
|
||||||
|
|
|
@ -18,7 +18,6 @@ class Tacotron(nn.Module):
|
||||||
self.embedding.weight.data.normal_(0, 0.3)
|
self.embedding.weight.data.normal_(0, 0.3)
|
||||||
self.encoder = Encoder(embedding_dim)
|
self.encoder = Encoder(embedding_dim)
|
||||||
self.decoder = Decoder(256, mel_dim, r)
|
self.decoder = Decoder(256, mel_dim, r)
|
||||||
self.stopnet = nn.Sequential(nn.Linear(80, 1), nn.Sigmoid())
|
|
||||||
self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
|
self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
|
||||||
self.last_linear = nn.Linear(mel_dim * 2, linear_dim)
|
self.last_linear = nn.Linear(mel_dim * 2, linear_dim)
|
||||||
|
|
||||||
|
@ -28,12 +27,11 @@ class Tacotron(nn.Module):
|
||||||
# batch x time x dim
|
# batch x time x dim
|
||||||
encoder_outputs = self.encoder(inputs)
|
encoder_outputs = self.encoder(inputs)
|
||||||
# batch x time x dim*r
|
# batch x time x dim*r
|
||||||
mel_outputs, alignments = self.decoder(
|
mel_outputs, alignments, stop_tokens = self.decoder(
|
||||||
encoder_outputs, mel_specs)
|
encoder_outputs, mel_specs)
|
||||||
# Reshape
|
# Reshape
|
||||||
# batch x time x dim
|
# batch x time x dim
|
||||||
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
|
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
|
||||||
stop_tokens = self.stopnet(mel_outputs).squeeze()
|
|
||||||
linear_outputs = self.postnet(mel_outputs)
|
linear_outputs = self.postnet(mel_outputs)
|
||||||
linear_outputs = self.last_linear(linear_outputs)
|
linear_outputs = self.last_linear(linear_outputs)
|
||||||
return mel_outputs, linear_outputs, alignments, stop_tokens
|
return mel_outputs, linear_outputs, alignments, stop_tokens
|
||||||
|
|
|
@ -23,7 +23,7 @@ def create_speech(m, s, CONFIG, use_cuda, ap):
|
||||||
torch.from_numpy(seq), volatile=True).unsqueeze(0)
|
torch.from_numpy(seq), volatile=True).unsqueeze(0)
|
||||||
# mel_var = torch.autograd.Variable(torch.from_numpy(mel).type(torch.FloatTensor), volatile=True)
|
# mel_var = torch.autograd.Variable(torch.from_numpy(mel).type(torch.FloatTensor), volatile=True)
|
||||||
|
|
||||||
mel_out, linear_out, alignments = m.forward(chars_var)
|
mel_out, linear_out, alignments, stop_tokens = m.forward(chars_var)
|
||||||
linear_out = linear_out[0].data.cpu().numpy()
|
linear_out = linear_out[0].data.cpu().numpy()
|
||||||
alignment = alignments[0].cpu().data.numpy()
|
alignment = alignments[0].cpu().data.numpy()
|
||||||
spec = ap._denormalize(linear_out)
|
spec = ap._denormalize(linear_out)
|
||||||
|
@ -31,7 +31,7 @@ def create_speech(m, s, CONFIG, use_cuda, ap):
|
||||||
wav = wav[:ap.find_endpoint(wav)]
|
wav = wav[:ap.find_endpoint(wav)]
|
||||||
out = io.BytesIO()
|
out = io.BytesIO()
|
||||||
ap.save_wav(wav, out)
|
ap.save_wav(wav, out)
|
||||||
return wav, alignment, spec
|
return wav, alignment, spec, stop_tokens
|
||||||
|
|
||||||
|
|
||||||
def visualize(alignment, spectrogram, CONFIG):
|
def visualize(alignment, spectrogram, CONFIG):
|
||||||
|
|
2
train.py
2
train.py
|
@ -102,6 +102,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
|
||||||
linear_spec = linear_spec.cuda()
|
linear_spec = linear_spec.cuda()
|
||||||
stop_target = stop_target.cuda()
|
stop_target = stop_target.cuda()
|
||||||
|
|
||||||
|
stop_target = stop_target.view(B, stop_target.size(1) // c.r, -1)
|
||||||
|
stop_target = (stop_target.sum(1) > 0.0).long()
|
||||||
|
|
||||||
# create attention mask
|
# create attention mask
|
||||||
if c.mk > 0.0:
|
if c.mk > 0.0:
|
||||||
|
|
Loading…
Reference in New Issue