mirror of https://github.com/coqui-ai/TTS.git

update align_tts.py model for the trainer
commit c82d91051d (parent 4f66e816d1)
--- a/TTS/tts/models/align_tts.py
+++ b/TTS/tts/models/align_tts.py
@@ -4,6 +4,9 @@ import torch.nn as nn
 from TTS.tts.layers.align_tts.mdn import MDNBlock
 from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio import AudioProcessor
 from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
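
The new imports feed the trainer hooks added further down: alignment_diagonal_score backs the align_error metric in train_step, while the plot_* helpers and AudioProcessor serve train_log. As a rough illustration of what that metric measures, here is a toy stand-in (diagonal_score is hypothetical, with assumed semantics, not the library's exact implementation):

    import torch

    def diagonal_score(alignments):
        # Toy stand-in for alignment_diagonal_score (assumed semantics):
        # average attention mass that the strongest input token receives at
        # each decoder step; a hard, monotonic alignment scores 1.0.
        return alignments.max(dim=1)[0].mean().item()

    attn = torch.eye(4).unsqueeze(0)  # [B, T_en, T_de], perfectly diagonal
    print(diagonal_score(attn))       # 1.0; train_step reports 1 - score
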
@@ -69,9 +72,19 @@ class AlignTTS(nn.Module):
         hidden_channels=256,
         hidden_channels_dp=256,
         encoder_type="fftransformer",
-        encoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
+        encoder_params={
+            "hidden_channels_ffn": 1024,
+            "num_heads": 2,
+            "num_layers": 6,
+            "dropout_p": 0.1
+        },
         decoder_type="fftransformer",
-        decoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
+        decoder_params={
+            "hidden_channels_ffn": 1024,
+            "num_heads": 2,
+            "num_layers": 6,
+            "dropout_p": 0.1
+        },
         length_scale=1,
         num_speakers=0,
         external_c=False,
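
Splitting encoder_params and decoder_params over multiple lines is cosmetic; the values are unchanged. For reference, a model overriding these defaults could be built like this (a sketch; num_chars and out_channels are required arguments and the values here are assumed for illustration):

    from TTS.tts.models.align_tts import AlignTTS

    model = AlignTTS(
        num_chars=61,         # size of the input symbol set (assumed)
        out_channels=80,      # number of mel channels (assumed)
        encoder_type="fftransformer",
        encoder_params={
            "hidden_channels_ffn": 1024,
            "num_heads": 2,
            "num_layers": 4,  # shallower than the default 6
            "dropout_p": 0.1,
        },
    )
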
@@ -79,11 +92,15 @@ class AlignTTS(nn.Module):
     ):
 
         super().__init__()
-        self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
+        self.phase = -1
+        self.length_scale = float(length_scale) if isinstance(
+            length_scale, int) else length_scale
         self.emb = nn.Embedding(num_chars, hidden_channels)
         self.pos_encoder = PositionalEncoding(hidden_channels)
-        self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
-        self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
+        self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type,
+                               encoder_params, c_in_channels)
+        self.decoder = Decoder(out_channels, hidden_channels, decoder_type,
+                               decoder_params)
         self.duration_predictor = DurationPredictor(hidden_channels_dp)
 
         self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1)
@@ -104,9 +121,9 @@ class AlignTTS(nn.Module):
         mu = mu.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
         log_sigma = log_sigma.transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
         expanded_y, expanded_mu = torch.broadcast_tensors(y, mu)
-        exponential = -0.5 * torch.mean(
-            torch._C._nn.mse_loss(expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2), dim=-1
-        )  # B, L, T
+        exponential = -0.5 * torch.mean(torch._C._nn.mse_loss(
+            expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2),
+                                        dim=-1)  # B, L, T
         logp = exponential - 0.5 * log_sigma.mean(dim=-1)
         return logp
 
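
The rewrapped expression is the same Gaussian score as before: torch._C._nn.mse_loss(a, b, 0) is the un-reduced squared error, so the block scores each mel frame under each token's Gaussian via broadcasting. A minimal self-contained sketch of the shape arithmetic (all sizes hypothetical):

    import torch

    B, T2, T1, D = 2, 5, 7, 80            # batch, tokens, frames, channels
    y = torch.randn(B, 1, T1, D)          # mel frames, broadcast over tokens
    mu = torch.randn(B, T2, 1, D)         # per-token Gaussian means
    log_sigma = torch.randn(B, T2, 1, D)  # per-token log-sigmas

    sq_err = (y - mu) ** 2                # what mse_loss(..., 0) returns
    exponential = -0.5 * torch.mean(sq_err / log_sigma.exp() ** 2, dim=-1)
    logp = exponential - 0.5 * log_sigma.mean(dim=-1)
    print(logp.shape)                     # torch.Size([2, 5, 7]) -> [B, T2, T1]
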
@@ -140,7 +157,9 @@ class AlignTTS(nn.Module):
         [1, 0, 0, 0, 0, 0, 0]]
         """
         attn = self.convert_dr_to_align(dr, x_mask, y_mask)
-        o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
+        o_en_ex = torch.matmul(
+            attn.squeeze(1).transpose(1, 2), en.transpose(1,
+                                                          2)).transpose(1, 2)
         return o_en_ex, attn
 
     def format_durations(self, o_dr_log, x_mask):
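
The matmul above is duration-based expansion: each encoder frame is repeated dr[t] times along the decoder time axis. A self-contained sketch of the equivalence (shapes simplified; the real attn carries an extra dim that gets squeezed):

    import torch

    en = torch.arange(12, dtype=torch.float32).view(1, 4, 3)  # [B, C, T_en]
    dr = torch.tensor([[2, 1, 3]])                            # durations per token

    # Dense 0/1 alignment, equivalent to convert_dr_to_align's output:
    attn = torch.zeros(1, 3, 6)                               # [B, T_en, T_de]
    frame = 0
    for t, d in enumerate(dr[0].tolist()):
        attn[0, t, frame:frame + d] = 1.0
        frame += d

    o_en_ex = torch.matmul(attn.transpose(1, 2),
                           en.transpose(1, 2)).transpose(1, 2)
    # The matmul against a 0/1 alignment equals a plain repeat:
    assert torch.equal(o_en_ex, en.repeat_interleave(dr[0], dim=2))
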
@@ -174,7 +193,8 @@ class AlignTTS(nn.Module):
         x_emb = torch.transpose(x_emb, 1, -1)
 
         # compute sequence masks
-        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
+                                 1).to(x.dtype)
 
         # encoder pass
         o_en = self.encoder(x_emb, x_mask)
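
The sequence_mask calls rewrapped in this and the next hunks all build padding masks of shape [B, 1, T]. A minimal stand-in showing the assumed semantics (not the library's exact implementation):

    import torch

    def sequence_mask(lengths, max_len=None):
        # Stand-in for the sequence_mask helper used in this file (assumed
        # semantics): mask[b, t] is True for t < lengths[b].
        max_len = max_len or int(lengths.max())
        ids = torch.arange(max_len, device=lengths.device)
        return ids.unsqueeze(0) < lengths.unsqueeze(1)

    x_lengths = torch.tensor([5, 3])
    x_mask = torch.unsqueeze(sequence_mask(x_lengths, 5), 1).float()  # [B, 1, T]
    # tensor([[[1., 1., 1., 1., 1.]],
    #         [[1., 1., 1., 0., 0.]]])
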
@@ -187,7 +207,8 @@ class AlignTTS(nn.Module):
         return o_en, o_en_dp, x_mask, g
 
     def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
+                                 1).to(o_en_dp.dtype)
         # expand o_en with durations
         o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
         # positional encoding
@@ -203,11 +224,13 @@ class AlignTTS(nn.Module):
     def _forward_mdn(self, o_en, y, y_lengths, x_mask):
         # MAS potentials and alignment
         mu, log_sigma = self.mdn_block(o_en)
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
-        dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask)
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
+                                 1).to(o_en.dtype)
+        dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask,
+                                               y_mask)
         return dr_mas, mu, log_sigma, logp
 
-    def forward(self, x, x_lengths, y, y_lengths, phase=None, g=None):  # pylint: disable=unused-argument
+    def forward(self, x, x_lengths, y, y_lengths, cond_input={"x_vectors": None}, phase=None):  # pylint: disable=unused-argument
         """
         Shapes:
             x: [B, T_max]
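
forward now takes conditioning through a cond_input dict instead of a bare g argument, matching the trainer's batch format. A hedged smoke test of the new signature (model construction and all sizes are assumed; phase 2 is used so the decoder branch actually runs):

    import torch
    from TTS.tts.models.align_tts import AlignTTS

    model = AlignTTS(num_chars=61, out_channels=80)  # assumed sizes
    x = torch.randint(0, 61, (2, 11))                # token ids [B, T_max]
    x_lengths = torch.tensor([11, 9])
    y = torch.randn(2, 50, 80)                       # [B, T_de, C]; forward transposes it
    y_lengths = torch.tensor([50, 42])

    outputs = model.forward(x, x_lengths, y, y_lengths,
                            cond_input={"x_vectors": None}, phase=2)
    print(outputs["model_outputs"].shape)            # [2, 50, 80]
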
@@ -216,47 +239,85 @@ class AlignTTS(nn.Module):
             dr: [B, T_max]
             g: [B, C]
         """
+        y = y.transpose(1, 2)
+        g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None
         o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None
         if phase == 0:
             # train encoder and MDN
             o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
-            dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
-            y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
+            dr_mas, mu, log_sigma, logp = self._forward_mdn(
+                o_en, y, y_lengths, x_mask)
+            y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
+                                     1).to(o_en_dp.dtype)
             attn = self.convert_dr_to_align(dr_mas, x_mask, y_mask)
         elif phase == 1:
             # train decoder
             o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
             dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask)
-            o_de, attn = self._forward_decoder(o_en.detach(), o_en_dp.detach(), dr_mas.detach(), x_mask, y_lengths, g=g)
+            o_de, attn = self._forward_decoder(o_en.detach(),
+                                               o_en_dp.detach(),
+                                               dr_mas.detach(),
+                                               x_mask,
+                                               y_lengths,
+                                               g=g)
         elif phase == 2:
             # train the whole except duration predictor
             o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
-            dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
-            o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
+            dr_mas, mu, log_sigma, logp = self._forward_mdn(
+                o_en, y, y_lengths, x_mask)
+            o_de, attn = self._forward_decoder(o_en,
+                                               o_en_dp,
+                                               dr_mas,
+                                               x_mask,
+                                               y_lengths,
+                                               g=g)
         elif phase == 3:
             # train duration predictor
             o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
             o_dr_log = self.duration_predictor(x, x_mask)
-            dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
-            o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
+            dr_mas, mu, log_sigma, logp = self._forward_mdn(
+                o_en, y, y_lengths, x_mask)
+            o_de, attn = self._forward_decoder(o_en,
+                                               o_en_dp,
+                                               dr_mas,
+                                               x_mask,
+                                               y_lengths,
+                                               g=g)
             o_dr_log = o_dr_log.squeeze(1)
         else:
             o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
             o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
-            dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
-            o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
+            dr_mas, mu, log_sigma, logp = self._forward_mdn(
+                o_en, y, y_lengths, x_mask)
+            o_de, attn = self._forward_decoder(o_en,
+                                               o_en_dp,
+                                               dr_mas,
+                                               x_mask,
+                                               y_lengths,
+                                               g=g)
             o_dr_log = o_dr_log.squeeze(1)
         dr_mas_log = torch.log(dr_mas + 1).squeeze(1)
-        return o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp
+        outputs = {
+            'model_outputs': o_de.transpose(1, 2),
+            'alignments': attn,
+            'durations_log': o_dr_log,
+            'durations_mas_log': dr_mas_log,
+            'mu': mu,
+            'log_sigma': log_sigma,
+            'logp': logp
+        }
+        return outputs
 
     @torch.no_grad()
-    def inference(self, x, x_lengths, g=None):  # pylint: disable=unused-argument
+    def inference(self, x, cond_input={'x_vectors': None}):  # pylint: disable=unused-argument
         """
         Shapes:
             x: [B, T_max]
             x_lengths: [B]
             g: [B, C]
         """
+        g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None
+        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
         # pad input to prevent dropping the last word
         # x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
         o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
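
inference likewise returns a dict now and derives x_lengths internally from x.shape, so callers pass only token ids and cond_input. Hypothetical usage (untrained weights, all sizes assumed):

    import torch
    from TTS.tts.models.align_tts import AlignTTS

    model = AlignTTS(num_chars=61, out_channels=80).eval()  # assumed sizes
    x = torch.randint(0, 61, (1, 13))  # one tokenized sentence

    outputs = model.inference(x, cond_input={"x_vectors": None})
    mel = outputs["model_outputs"]     # [1, T_de, 80]; T_de from predicted durations
    attn = outputs["alignments"]
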
@@ -265,14 +326,91 @@ class AlignTTS(nn.Module):
         # duration predictor pass
         o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
         y_lengths = o_dr.sum(1)
-        o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
-        return o_de, attn
+        o_de, attn = self._forward_decoder(o_en,
+                                           o_en_dp,
+                                           o_dr,
+                                           x_mask,
+                                           y_lengths,
+                                           g=g)
+        outputs = {'model_outputs': o_de.transpose(1, 2), 'alignments': attn}
+        return outputs
 
-    def load_checkpoint(
-        self, config, checkpoint_path, eval=False
-    ):  # pylint: disable=unused-argument, redefined-builtin
+    def train_step(self, batch: dict, criterion: nn.Module):
+        text_input = batch['text_input']
+        text_lengths = batch['text_lengths']
+        mel_input = batch['mel_input']
+        mel_lengths = batch['mel_lengths']
+        x_vectors = batch['x_vectors']
+        speaker_ids = batch['speaker_ids']
+
+        cond_input = {'x_vectors': x_vectors, 'speaker_ids': speaker_ids}
+        outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, cond_input, self.phase)
+        loss_dict = criterion(
+            outputs['logp'],
+            outputs['model_outputs'],
+            mel_input,
+            mel_lengths,
+            outputs['durations_log'],
+            outputs['durations_mas_log'],
+            text_lengths,
+            phase=self.phase,
+        )
+
+        # compute alignment error (the lower the better)
+        align_error = 1 - alignment_diagonal_score(outputs['alignments'],
+                                                   binary=True)
+        loss_dict["align_error"] = align_error
+        return outputs, loss_dict
+
+    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
+        model_outputs = outputs['model_outputs']
+        alignments = outputs['alignments']
+        mel_input = batch['mel_input']
+
+        pred_spec = model_outputs[0].data.cpu().numpy()
+        gt_spec = mel_input[0].data.cpu().numpy()
+        align_img = alignments[0].data.cpu().numpy()
+
+        figures = {
+            "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
+            "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
+            "alignment": plot_alignment(align_img, output_fig=False),
+        }
+
+        # Sample audio
+        train_audio = ap.inv_melspectrogram(pred_spec.T)
+        return figures, train_audio
+
+    def eval_step(self, batch: dict, criterion: nn.Module):
+        return self.train_step(batch, criterion)
+
+    def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
+        return self.train_log(ap, batch, outputs)
+
+    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
         state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
             self.eval()
             assert not self.training
+
+    @staticmethod
+    def _set_phase(config, global_step):
+        """Decide AlignTTS training phase"""
+        if isinstance(config.phase_start_steps, list):
+            vals = [i < global_step for i in config.phase_start_steps]
+            if not True in vals:
+                phase = 0
+            else:
+                phase = (
+                    len(config.phase_start_steps)
+                    - [i < global_step for i in config.phase_start_steps][::-1].index(True)
+                    - 1
+                )
+        else:
+            phase = None
+        return phase
+
+    def on_epoch_start(self, trainer):
+        """Set AlignTTS training phase on epoch start."""
+        self.phase = self._set_phase(trainer.config, trainer.total_steps_done)
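
_set_phase resolves the current training phase from config.phase_start_steps: the phase is the index of the last threshold the global step has passed, or 0 before any threshold. A standalone rendering of the list branch (the threshold values are assumed for illustration):

    def set_phase(phase_start_steps, global_step):
        # Same logic as AlignTTS._set_phase for the list case.
        vals = [i < global_step for i in phase_start_steps]
        if True not in vals:
            return 0
        return len(phase_start_steps) - vals[::-1].index(True) - 1

    phase_start_steps = [0, 40000, 80000, 160000, 170000]  # assumed schedule
    for step in (0, 50000, 100000, 165000, 200000):
        print(step, set_phase(phase_start_steps, step))
    # 0 -> 0, 50000 -> 1, 100000 -> 2, 165000 -> 3, 200000 -> 4

on_epoch_start then stores the result on the model, which is why train_step can pass self.phase into forward.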