mirror of https://github.com/coqui-ai/TTS.git
commit 71af8da293
Merge branch 'upstream_clean' of https://github.com/geneing/TTS into geneing-upstream_clean
@@ -127,8 +127,8 @@ class GravesAttention(nn.Module):
         self.init_layers()
 
     def init_layers(self):
-        torch.nn.init.constant_(self.N_a[2].bias[10:15], 0.5)
-        torch.nn.init.constant_(self.N_a[2].bias[5:10], 10)
+        torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.)
+        torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10)
 
     def init_states(self, inputs):
        if self.J is None or inputs.shape[1] > self.J.shape[-1]:
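The init change above replaces magic slice bounds with ones derived from self.K: the final linear layer of N_a emits 3*K values per decoder step, laid out as K mixture weights (g), K scale logits (b) and K step logits (k), so the hardcoded [5:10] and [10:15] slices were only correct for K=5. A minimal sketch of the layout the new init assumes; K and the input width here are illustrative, not values from the repo:

    import torch

    K = 10                                      # illustrative; the old slices assumed K == 5
    out_layer = torch.nn.Linear(128, 3 * K)     # stand-in for self.N_a[2]
    # bias chunks follow the (g, b, k) layout used by gbk_t[:, 0/1/2, :]
    torch.nn.init.constant_(out_layer.bias[K:(2 * K)], 10)        # b: scale logits
    torch.nn.init.constant_(out_layer.bias[(2 * K):(3 * K)], 1.)  # k: step logits
    gbk = out_layer(torch.zeros(2, 128)).view(2, 3, K)
    g_t, b_t, k_t = gbk[:, 0, :], gbk[:, 1, :], gbk[:, 2, :]      # each 2 x K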
@@ -159,20 +159,21 @@ class GravesAttention(nn.Module):
         k_t = gbk_t[:, 2, :]
 
         # attention GMM parameters
-        inv_sig_t = torch.exp(-torch.clamp(b_t, min=-6, max=9))  # variance
+        sig_t = torch.nn.functional.softplus(b_t)+self.eps
 
         mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
-        g_t = torch.softmax(g_t, dim=-1) * inv_sig_t + self.eps
+        g_t = torch.softmax(g_t, dim=-1) / sig_t + self.eps
 
         # each B x K x T_in
         g_t = g_t.unsqueeze(2).expand(g_t.size(0),
                                       g_t.size(1),
                                       inputs.size(1))
-        inv_sig_t = inv_sig_t.unsqueeze(2).expand_as(g_t)
+        sig_t = sig_t.unsqueeze(2).expand_as(g_t)
         mu_t_ = mu_t.unsqueeze(2).expand_as(g_t)
         j = self.J[:g_t.size(0), :, :inputs.size(1)]
 
         # attention weights
-        phi_t = g_t * torch.exp(-0.5 * inv_sig_t * (mu_t_ - j)**2)
+        phi_t = g_t * torch.exp(-0.5 * (mu_t_ - j)**2 / (sig_t**2))
         alpha_t = self.COEF * torch.sum(phi_t, 1)
 
         # apply masking
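The forward change swaps the clamped exp(-b) inverse-variance for a softplus standard deviation, so component k weights encoder position j by a properly scaled Gaussian, (w_k / sig_k) * exp(-(j - mu_k)^2 / (2 * sig_k^2)), and mu moves strictly forward because softplus(k_t) > 0. A minimal standalone sketch of the new computation, assuming self.COEF is the usual 1/sqrt(2*pi) normalizer; the function name and eps default are illustrative and mask handling is omitted:

    import math
    import torch
    import torch.nn.functional as F

    def graves_gmm_weights(g_t, b_t, k_t, mu_prev, T_in, eps=1e-5):
        # B x K parameters of a K-component GMM over encoder positions
        sig_t = F.softplus(b_t) + eps                   # positive std, no clamping
        mu_t = mu_prev + F.softplus(k_t)                # means can only move forward
        w_t = torch.softmax(g_t, dim=-1) / sig_t + eps  # mixture weight over sigma
        j = torch.arange(T_in, dtype=mu_t.dtype, device=mu_t.device)
        # broadcast to B x K x T_in, then sum over the K components
        phi_t = w_t.unsqueeze(-1) * torch.exp(
            -0.5 * (mu_t.unsqueeze(-1) - j) ** 2 / sig_t.unsqueeze(-1) ** 2)
        alpha_t = (1.0 / math.sqrt(2 * math.pi)) * phi_t.sum(1)  # B x T_in
        return alpha_t, mu_t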
@@ -55,8 +55,11 @@ def tts():
 
 
 if __name__ == '__main__':
-    if not config or not synthesizer:
-        args = create_argparser().parse_args()
-        synthesizer = Synthesizer(args)
+    args = create_argparser().parse_args()
+
+    # Setup synthesizer from CLI args if they're specified or no embedded model
+    # is present.
+    if not config or not synthesizer or args.tts_checkpoint or args.tts_config:
+        synthesizer = Synthesizer(args)
 
     app.run(debug=config.debug, host='0.0.0.0', port=config.port)
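The entry point now always parses the CLI arguments and rebuilds the synthesizer when either no embedded model ships with the package or the user explicitly points at a checkpoint or config on the command line. A minimal sketch of that precedence rule; the function name is illustrative, and args is assumed to carry the tts_checkpoint/tts_config attributes used above:

    def should_build_from_args(config, synthesizer, args):
        # no packaged model, or an explicit CLI override of checkpoint/config
        return (not config or not synthesizer
                or bool(args.tts_checkpoint) or bool(args.tts_config))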
@@ -53,15 +53,14 @@ class Synthesizer(object):
             num_speakers = 0
         self.tts_model = setup_model(self.input_size, num_speakers=num_speakers, c=self.tts_config)
         # load model state
-        map_location = None if use_cuda else torch.device('cpu')
-        cp = torch.load(tts_checkpoint, map_location=map_location)
+        cp = torch.load(tts_checkpoint, map_location=torch.device('cpu'))
         # load the model
         self.tts_model.load_state_dict(cp['model'])
         if use_cuda:
             self.tts_model.cuda()
         self.tts_model.eval()
         self.tts_model.decoder.max_decoder_steps = 3000
-        if 'r' in cp and self.tts_config.model in ["Tacotron", "TacotronGST"]:
+        if 'r' in cp:
             self.tts_model.decoder.set_r(cp['r'])
 
     def load_wavernn(self, lib_path, model_path, model_file, model_config, use_cuda):
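Loading with map_location=torch.device('cpu') makes deserialization device-independent: tensors saved from a CUDA run no longer fail to load on a CPU-only host, and the model is moved to the GPU afterwards only if requested. A minimal sketch of the pattern, assuming the checkpoint dict stores weights under 'model' as above:

    import torch

    def load_tts_checkpoint(model, checkpoint_path, use_cuda=False):
        # always deserialize onto the CPU, then move if a GPU is wanted
        cp = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        model.load_state_dict(cp['model'])
        if use_cuda:
            model.cuda()
        model.eval()   # inference mode: fixes dropout/batch-norm behavior
        return cp      # extras such as cp['r'] stay available to the caller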