mirror of https://github.com/coqui-ai/TTS.git
update glow-tts for the trainer
This commit is contained in:
parent
3346a6d9dc
commit
f09ec7e3a7
|
@ -38,7 +38,6 @@ class GlowTTS(nn.Module):
|
||||||
encoder_params (dict): encoder module parameters.
|
encoder_params (dict): encoder module parameters.
|
||||||
speaker_embedding_dim (int): channels of external speaker embedding vectors.
|
speaker_embedding_dim (int): channels of external speaker embedding vectors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_chars,
|
num_chars,
|
||||||
|
@ -133,27 +132,29 @@ class GlowTTS(nn.Module):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def compute_outputs(attn, o_mean, o_log_scale, x_mask):
|
def compute_outputs(attn, o_mean, o_log_scale, x_mask):
|
||||||
# compute final values with the computed alignment
|
# compute final values with the computed alignment
|
||||||
y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
|
y_mean = torch.matmul(
|
||||||
1, 2
|
attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
|
||||||
) # [b, t', t], [b, t, d] -> [b, d, t']
|
1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||||
y_log_scale = torch.matmul(attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(1, 2)).transpose(
|
y_log_scale = torch.matmul(
|
||||||
1, 2
|
attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(
|
||||||
) # [b, t', t], [b, t, d] -> [b, d, t']
|
1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||||
# compute total duration with adjustment
|
# compute total duration with adjustment
|
||||||
o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
|
o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
|
||||||
return y_mean, y_log_scale, o_attn_dur
|
return y_mean, y_log_scale, o_attn_dur
|
||||||
|
|
||||||
def forward(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
|
def forward(self, x, x_lengths, y, y_lengths=None, cond_input={'x_vectors':None}):
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
x: [B, T]
|
x: [B, T]
|
||||||
x_lenghts: B
|
x_lenghts: B
|
||||||
y: [B, C, T]
|
y: [B, T, C]
|
||||||
y_lengths: B
|
y_lengths: B
|
||||||
g: [B, C] or B
|
g: [B, C] or B
|
||||||
"""
|
"""
|
||||||
y_max_length = y.size(2)
|
y_max_length = y.size(2)
|
||||||
|
y = y.transpose(1, 2)
|
||||||
# norm speaker embeddings
|
# norm speaker embeddings
|
||||||
|
g = cond_input['x_vectors']
|
||||||
if g is not None:
|
if g is not None:
|
||||||
if self.speaker_embedding_dim:
|
if self.speaker_embedding_dim:
|
||||||
g = F.normalize(g).unsqueeze(-1)
|
g = F.normalize(g).unsqueeze(-1)
|
||||||
|
@ -161,29 +162,54 @@ class GlowTTS(nn.Module):
|
||||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||||
|
|
||||||
# embedding pass
|
# embedding pass
|
||||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
|
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
|
||||||
|
x_lengths,
|
||||||
|
g=g)
|
||||||
# drop redisual frames wrt num_squeeze and set y_lengths.
|
# drop redisual frames wrt num_squeeze and set y_lengths.
|
||||||
y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
|
y, y_lengths, y_max_length, attn = self.preprocess(
|
||||||
|
y, y_lengths, y_max_length, None)
|
||||||
# create masks
|
# create masks
|
||||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
|
||||||
|
1).to(x_mask.dtype)
|
||||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||||
# decoder pass
|
# decoder pass
|
||||||
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
|
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
|
||||||
# find the alignment path
|
# find the alignment path
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
o_scale = torch.exp(-2 * o_log_scale)
|
o_scale = torch.exp(-2 * o_log_scale)
|
||||||
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
|
||||||
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
|
[1]).unsqueeze(-1) # [b, t, 1]
|
||||||
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
|
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
|
||||||
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
(z**2)) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||||
|
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
|
||||||
|
z) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||||
|
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
|
||||||
|
[1]).unsqueeze(-1) # [b, t, 1]
|
||||||
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
|
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
|
||||||
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
attn = maximum_path(logp,
|
||||||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
|
attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||||
|
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
|
||||||
|
attn, o_mean, o_log_scale, x_mask)
|
||||||
attn = attn.squeeze(1).permute(0, 2, 1)
|
attn = attn.squeeze(1).permute(0, 2, 1)
|
||||||
return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
|
outputs = {
|
||||||
|
'model_outputs': z,
|
||||||
|
'logdet': logdet,
|
||||||
|
'y_mean': y_mean,
|
||||||
|
'y_log_scale': y_log_scale,
|
||||||
|
'alignments': attn,
|
||||||
|
'durations_log': o_dur_log,
|
||||||
|
'total_durations_log': o_attn_dur
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference_with_MAS(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
|
def inference_with_MAS(self,
|
||||||
|
x,
|
||||||
|
x_lengths,
|
||||||
|
y=None,
|
||||||
|
y_lengths=None,
|
||||||
|
attn=None,
|
||||||
|
g=None):
|
||||||
"""
|
"""
|
||||||
It's similar to the teacher forcing in Tacotron.
|
It's similar to the teacher forcing in Tacotron.
|
||||||
It was proposed in: https://arxiv.org/abs/2104.05557
|
It was proposed in: https://arxiv.org/abs/2104.05557
|
||||||
|
@ -203,24 +229,33 @@ class GlowTTS(nn.Module):
|
||||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||||
|
|
||||||
# embedding pass
|
# embedding pass
|
||||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
|
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
|
||||||
|
x_lengths,
|
||||||
|
g=g)
|
||||||
# drop redisual frames wrt num_squeeze and set y_lengths.
|
# drop redisual frames wrt num_squeeze and set y_lengths.
|
||||||
y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
|
y, y_lengths, y_max_length, attn = self.preprocess(
|
||||||
|
y, y_lengths, y_max_length, None)
|
||||||
# create masks
|
# create masks
|
||||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
|
||||||
|
1).to(x_mask.dtype)
|
||||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||||
# decoder pass
|
# decoder pass
|
||||||
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
|
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
|
||||||
# find the alignment path between z and encoder output
|
# find the alignment path between z and encoder output
|
||||||
o_scale = torch.exp(-2 * o_log_scale)
|
o_scale = torch.exp(-2 * o_log_scale)
|
||||||
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
|
||||||
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
|
[1]).unsqueeze(-1) # [b, t, 1]
|
||||||
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
|
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
|
||||||
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
|
(z**2)) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||||
|
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
|
||||||
|
z) # [b, t, d] x [b, d, t'] = [b, t, t']
|
||||||
|
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
|
||||||
|
[1]).unsqueeze(-1) # [b, t, 1]
|
||||||
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
|
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
|
||||||
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||||
|
|
||||||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
|
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
|
||||||
|
attn, o_mean, o_log_scale, x_mask)
|
||||||
attn = attn.squeeze(1).permute(0, 2, 1)
|
attn = attn.squeeze(1).permute(0, 2, 1)
|
||||||
|
|
||||||
# get predited aligned distribution
|
# get predited aligned distribution
|
||||||
|
@ -228,8 +263,16 @@ class GlowTTS(nn.Module):
|
||||||
|
|
||||||
# reverse the decoder and predict using the aligned distribution
|
# reverse the decoder and predict using the aligned distribution
|
||||||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||||
|
outputs = {
|
||||||
return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
|
'model_outputs': y,
|
||||||
|
'logdet': logdet,
|
||||||
|
'y_mean': y_mean,
|
||||||
|
'y_log_scale': y_log_scale,
|
||||||
|
'alignments': attn,
|
||||||
|
'durations_log': o_dur_log,
|
||||||
|
'total_durations_log': o_attn_dur
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def decoder_inference(self, y, y_lengths=None, g=None):
|
def decoder_inference(self, y, y_lengths=None, g=None):
|
||||||
|
@ -247,7 +290,8 @@ class GlowTTS(nn.Module):
|
||||||
else:
|
else:
|
||||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||||
|
|
||||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype)
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
|
||||||
|
1).to(y.dtype)
|
||||||
|
|
||||||
# decoder pass
|
# decoder pass
|
||||||
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
|
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
|
||||||
|
@ -266,28 +310,98 @@ class GlowTTS(nn.Module):
|
||||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
|
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
|
||||||
|
|
||||||
# embedding pass
|
# embedding pass
|
||||||
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
|
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
|
||||||
|
x_lengths,
|
||||||
|
g=g)
|
||||||
# compute output durations
|
# compute output durations
|
||||||
w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
|
w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
|
||||||
w_ceil = torch.ceil(w)
|
w_ceil = torch.ceil(w)
|
||||||
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
|
||||||
y_max_length = None
|
y_max_length = None
|
||||||
# compute masks
|
# compute masks
|
||||||
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
|
||||||
|
1).to(x_mask.dtype)
|
||||||
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||||
# compute attention mask
|
# compute attention mask
|
||||||
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
|
attn = generate_path(w_ceil.squeeze(1),
|
||||||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
|
attn_mask.squeeze(1)).unsqueeze(1)
|
||||||
|
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
|
||||||
|
attn, o_mean, o_log_scale, x_mask)
|
||||||
|
|
||||||
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.inference_noise_scale) * y_mask
|
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) *
|
||||||
|
self.inference_noise_scale) * y_mask
|
||||||
# decoder pass
|
# decoder pass
|
||||||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||||
attn = attn.squeeze(1).permute(0, 2, 1)
|
attn = attn.squeeze(1).permute(0, 2, 1)
|
||||||
return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
|
outputs = {
|
||||||
|
'model_outputs': y,
|
||||||
|
'logdet': logdet,
|
||||||
|
'y_mean': y_mean,
|
||||||
|
'y_log_scale': y_log_scale,
|
||||||
|
'alignments': attn,
|
||||||
|
'durations_log': o_dur_log,
|
||||||
|
'total_durations_log': o_attn_dur
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def train_step(self, batch: dict, criterion: nn.Module):
|
||||||
|
"""Perform a single training step by fetching the right set if samples from the batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch (dict): [description]
|
||||||
|
criterion (nn.Module): [description]
|
||||||
|
"""
|
||||||
|
text_input = batch['text_input']
|
||||||
|
text_lengths = batch['text_lengths']
|
||||||
|
mel_input = batch['mel_input']
|
||||||
|
mel_lengths = batch['mel_lengths']
|
||||||
|
x_vectors = batch['x_vectors']
|
||||||
|
|
||||||
|
outputs = self.forward(text_input,
|
||||||
|
text_lengths,
|
||||||
|
mel_input,
|
||||||
|
mel_lengths,
|
||||||
|
cond_input={"x_vectors": x_vectors})
|
||||||
|
|
||||||
|
loss_dict = criterion(outputs['model_outputs'], outputs['y_mean'],
|
||||||
|
outputs['y_log_scale'], outputs['logdet'],
|
||||||
|
mel_lengths, outputs['durations_log'],
|
||||||
|
outputs['total_durations_log'], text_lengths)
|
||||||
|
|
||||||
|
# compute alignment error (the lower the better )
|
||||||
|
align_error = 1 - alignment_diagonal_score(outputs['alignments'], binary=True)
|
||||||
|
loss_dict["align_error"] = align_error
|
||||||
|
return outputs, loss_dict
|
||||||
|
|
||||||
|
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
||||||
|
model_outputs = outputs['model_outputs']
|
||||||
|
alignments = outputs['alignments']
|
||||||
|
mel_input = batch['mel_input']
|
||||||
|
|
||||||
|
pred_spec = model_outputs[0].data.cpu().numpy()
|
||||||
|
gt_spec = mel_input[0].data.cpu().numpy()
|
||||||
|
align_img = alignments[0].data.cpu().numpy()
|
||||||
|
|
||||||
|
figures = {
|
||||||
|
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
|
||||||
|
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
|
||||||
|
"alignment": plot_alignment(align_img, output_fig=False),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Sample audio
|
||||||
|
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
||||||
|
return figures, train_audio
|
||||||
|
|
||||||
|
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||||
|
return self.train_step(batch, criterion)
|
||||||
|
|
||||||
|
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
||||||
|
return self.train_log(ap, batch, outputs)
|
||||||
|
|
||||||
def preprocess(self, y, y_lengths, y_max_length, attn=None):
|
def preprocess(self, y, y_lengths, y_max_length, attn=None):
|
||||||
if y_max_length is not None:
|
if y_max_length is not None:
|
||||||
y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze
|
y_max_length = (y_max_length //
|
||||||
|
self.num_squeeze) * self.num_squeeze
|
||||||
y = y[:, :, :y_max_length]
|
y = y[:, :, :y_max_length]
|
||||||
if attn is not None:
|
if attn is not None:
|
||||||
attn = attn[:, :, :, :y_max_length]
|
attn = attn[:, :, :, :y_max_length]
|
||||||
|
@ -297,9 +411,7 @@ class GlowTTS(nn.Module):
|
||||||
def store_inverse(self):
|
def store_inverse(self):
|
||||||
self.decoder.store_inverse()
|
self.decoder.store_inverse()
|
||||||
|
|
||||||
def load_checkpoint(
|
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||||
self, config, checkpoint_path, eval=False
|
|
||||||
): # pylint: disable=unused-argument, redefined-builtin
|
|
||||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||||
self.load_state_dict(state["model"])
|
self.load_state_dict(state["model"])
|
||||||
if eval:
|
if eval:
|
||||||
|
|
Loading…
Reference in New Issue