Enable aligner for FastPitch

Eren Gölge 2021-07-25 01:24:17 +02:00
parent 81c228a2d8
commit e429afbce4
5 changed files with 163 additions and 15 deletions

View File

@@ -76,6 +76,7 @@ class FastPitchConfig(BaseTTSConfig):
spec_loss_alpha: float = 1.0
pitch_loss_alpha: float = 1.0
dur_loss_alpha: float = 1.0
aligner_loss_alpha: float = 1.0
# overrides
min_seq_len: int = 13
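The new aligner_loss_alpha weights the CTC alignment loss added below; it pairs with the use_aligner switch introduced in the model args later in this commit. A minimal sketch of enabling both (the import path and the model_args attribute are assumptions, not shown in this diff):

from TTS.tts.configs import FastPitchConfig

config = FastPitchConfig(aligner_loss_alpha=1.0)  # weight of the aligner CTC term; 0 disables it
config.model_args.use_aligner = True  # learn alignments on the fly instead of loading attention masks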

View File

@@ -47,8 +47,9 @@ def maximum_path(value, mask):
def maximum_path_cython(value, mask):
"""Cython optimised version.
value: [b, t_x, t_y]
mask: [b, t_x, t_y]
Shapes:
- value: :math:`[B, T_en, T_de]`
- mask: :math:`[B, T_en, T_de]`
"""
value = value * mask
device = value.device
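maximum_path runs monotonic alignment search (MAS, as in Glow-TTS): a dynamic program over the [T_en, T_de] likelihood grid followed by backtracking. For reference, a pure-NumPy transcription of the Cython kernel for a single batch item (the function name is illustrative, not part of the package):

import numpy as np

def maximum_path_numpy(value, mask, max_neg_val=-np.inf):
    # value, mask: [t_x, t_y] = [text tokens, decoder frames] for one item.
    value = value * mask
    t_x, t_y = value.shape
    path = np.zeros((t_x, t_y), dtype=np.int32)
    # Forward pass: value[x, y] becomes the best cumulative score of any
    # monotonic path that ends at text token x on frame y.
    for y in range(t_y):
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            v_cur = max_neg_val if x == y else value[x, y - 1]
            v_prev = (0.0 if y == 0 else max_neg_val) if x == 0 else value[x - 1, y - 1]
            value[x, y] += max(v_cur, v_prev)
    # Backtracking: from the last frame, either stay on the current token or
    # step back one token, following the larger cumulative score.
    index = t_x - 1
    for y in range(t_y - 1, -1, -1):
        path[index, y] = 1
        if index != 0 and (index == y or value[index, y - 1] < value[index - 1, y - 1]):
            index -= 1
    return path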

View File

@@ -660,6 +660,36 @@ class VitsDiscriminatorLoss(nn.Module):
return return_dict
class ForwardSumLoss(nn.Module):
def __init__(self, blank_logprob=-1):
super().__init__()
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True)
self.blank_logprob = blank_logprob
def forward(self, attn_logprob, in_lens, out_lens):
key_lens = in_lens
query_lens = out_lens
attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob)
total_loss = 0.0
for bid in range(attn_logprob.shape[0]):
target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1]
curr_logprob = self.log_softmax(curr_logprob[None])[0]
loss = self.ctc_loss(
curr_logprob,
target_seq,
input_lengths=query_lens[bid : bid + 1],
target_lengths=key_lens[bid : bid + 1],
)
total_loss = total_loss + loss
total_loss = total_loss / attn_logprob.shape[0]
return total_loss
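ForwardSumLoss treats the text positions as CTC symbols (index 0 is the injected blank, hence the pad), so maximizing the CTC forward score pushes the soft attention toward a monotonic alignment that covers every token. A standalone smoke test with random tensors, shapes matching the attn_logprob produced by the AlignmentEncoder added later in this commit:

import torch

loss_fn = ForwardSumLoss()
attn_logprob = torch.randn(2, 1, 50, 12)  # (B, 1, T_de, T_en), random stand-in
in_lens = torch.tensor([12, 9])  # text (key) lengths
out_lens = torch.tensor([50, 40])  # mel (query) lengths
loss = loss_fn(attn_logprob, in_lens, out_lens)  # scalar, averaged over the batch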
class FastPitchLoss(nn.Module):
def __init__(self, c):
super().__init__()
@@ -667,11 +697,14 @@ class FastPitchLoss(nn.Module):
self.ssim = SSIMLoss()
self.dur_loss = MSELossMasked(False)
self.pitch_loss = MSELossMasked(False)
if c.model_args.use_aligner:
self.aligner_loss = ForwardSumLoss()
self.spec_loss_alpha = c.spec_loss_alpha
self.ssim_loss_alpha = c.ssim_loss_alpha
self.dur_loss_alpha = c.dur_loss_alpha
self.pitch_loss_alpha = c.pitch_loss_alpha
self.aligner_loss_alpha = c.aligner_loss_alpha
def forward(
self,
@@ -683,29 +716,35 @@ class FastPitchLoss(nn.Module):
pitch_output,
pitch_target,
input_lens,
alignment_logprob=None,
):
loss = 0
return_dict = {}
if self.ssim_loss_alpha > 0:
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
loss += self.ssim_loss_alpha * ssim_loss
loss = loss + self.ssim_loss_alpha * ssim_loss
return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss
if self.spec_loss_alpha > 0:
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
loss += self.spec_loss_alpha * spec_loss
loss = loss + self.spec_loss_alpha * spec_loss
return_dict["loss_spec"] = self.spec_loss_alpha * spec_loss
if self.dur_loss_alpha > 0:
log_dur_tgt = torch.log(dur_target.float() + 1)
dur_loss = self.dur_loss(dur_output[:, :, None], log_dur_tgt[:, :, None], input_lens)
loss += self.dur_loss_alpha * dur_loss
loss = loss + self.dur_loss_alpha * dur_loss
return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss
if self.pitch_loss_alpha > 0:
pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens)
loss += self.pitch_loss_alpha * pitch_loss
loss = loss + self.pitch_loss_alpha * pitch_loss
return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss
if self.aligner_loss_alpha > 0:
aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
loss += self.aligner_loss_alpha * aligner_loss
return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
return_dict["loss"] = loss
return return_dict
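Each active term (alpha > 0) is scaled by its alpha, logged under its own key, and accumulated into the total. A reduced standalone sketch of that aggregation pattern with toy scalars (the real class needs a full config object and is not reproduced here):

import torch

alphas = {"spec": 1.0, "pitch": 1.0, "dur": 1.0, "aligner": 1.0}
terms = {name: torch.rand(()) for name in alphas}  # stand-ins for the real losses
return_dict = {"loss_" + n: alphas[n] * t for n, t in terms.items() if alphas[n] > 0}
return_dict["loss"] = sum(return_dict.values())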

View File

@@ -4,8 +4,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from coqpit import Coqpit
from matplotlib.pyplot import plot
from TTS.tts.layers.glow_tts.monotonic_align import generate_path
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@@ -14,6 +15,75 @@ from TTS.utils.audio import AudioProcessor
# pylint: disable=dangerous-default-value
class AlignmentEncoder(torch.nn.Module):
"""Module for alignment text and mel spectrogram."""
def __init__(
self,
in_query_channels=80,
in_key_channels=512,
attn_channels=80,
temperature=0.0005,
):
super().__init__()
self.temperature = temperature
self.softmax = torch.nn.Softmax(dim=3)
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.key_proj = nn.Sequential(
ConvNorm(
in_key_channels, in_key_channels * 2, kernel_size=3, bias=True, w_init_gain="relu", batch_norm=False
),
torch.nn.ReLU(),
ConvNorm(in_key_channels * 2, attn_channels, kernel_size=1, bias=True, batch_norm=False),
)
self.query_proj = nn.Sequential(
ConvNorm(
in_query_channels, in_query_channels * 2, kernel_size=3, bias=True, w_init_gain="relu", batch_norm=False
),
torch.nn.ReLU(),
ConvNorm(in_query_channels * 2, in_query_channels, kernel_size=1, bias=True, batch_norm=False),
torch.nn.ReLU(),
ConvNorm(in_query_channels, attn_channels, kernel_size=1, bias=True, batch_norm=False),
)
def forward(
self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None
):
"""Forward pass of the aligner encoder.
Args:
queries (torch.tensor): query tensor.
keys (torch.tensor): key tensor.
mask (torch.tensor): uint8 binary mask for variable length entries (should be in the T2 domain).
attn_prior (torch.tensor): prior for attention matrix.
Shapes:
- queries: :math:`(B, C, T_de)`
- keys: :math:`(B, C_emb, T_en)`
- mask: :math:`(B, 1, T_en)`
Output:
attn (torch.tensor): :math:`(B, 1, T_de, T_en)` attention map. The final dim (T_en) sums to 1.
attn_logprob (torch.tensor): :math:`(B, 1, T_de, T_en)` log-prob attention map.
"""
keys_enc = self.key_proj(keys) # B x n_attn_dims x T2
queries_enc = self.query_proj(queries)
# Simplistic isotropic Gaussian attention
attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2 # B x n_attn_dims x T1 x T2
attn = -self.temperature * attn.sum(1, keepdim=True)
if attn_prior is not None:
attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8)
attn_logprob = attn.clone()
if mask is not None:
attn.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))
attn = self.softmax(attn) # softmax along T2
return attn, attn_logprob
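A quick shape check with random tensors (assuming ConvNorm is importable in this module; the numbers are arbitrary):

import torch

aligner = AlignmentEncoder(in_query_channels=80, in_key_channels=512)
queries = torch.randn(2, 80, 120)  # mel frames, (B, C, T_de)
keys = torch.randn(2, 512, 15)  # text embeddings, (B, C_emb, T_en)
mask = torch.ones(2, 1, 15)  # key-side mask, as passed from FastPitch.forward
attn, attn_logprob = aligner(queries, keys, mask)
# attn: (2, 1, 120, 15); each row sums to 1 over the 15 text positions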
def mask_from_lens(lens, max_len: int = None):
if max_len is None:
max_len = lens.max()
@@ -379,6 +449,7 @@ class FastPitchArgs(Coqpit):
detach_duration_predictor: bool = False
max_duration: int = 75
use_gt_duration: bool = True
use_aligner: bool = True
class FastPitch(BaseTTS):
@@ -421,6 +492,7 @@ class FastPitch(BaseTTS):
self.max_duration = args.max_duration
self.use_gt_duration = args.use_gt_duration
self.use_aligner = args.use_aligner
self.length_scale = float(args.length_scale) if isinstance(args.length_scale, int) else args.length_scale
@@ -463,6 +535,9 @@ class FastPitch(BaseTTS):
self.proj = nn.Linear(args.hidden_channels, args.out_channels, bias=True)
if args.use_aligner:
self.aligner = AlignmentEncoder(args.out_channels, args.hidden_channels)
@staticmethod
def expand_encoder_outputs(en, dr, x_mask, y_mask):
"""Generate attention alignment map from durations and
@@ -483,10 +558,16 @@ class FastPitch(BaseTTS):
o_en_ex = torch.matmul(attn.transpose(1, 2), en)
return o_en_ex, attn.transpose(1, 2)
def forward(self, x, x_lengths, y_lengths, dr, pitch, aux_input={"d_vectors": 0, "speaker_ids": None}):
def forward(
self, x, x_lengths, y_lengths, y=None, dr=None, pitch=None, aux_input={"d_vectors": 0, "speaker_ids": None}
):
speaker_embedding = aux_input["d_vectors"] if "d_vectors" in aux_input else 0
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
o_alignment_dur = None
alignment_logprob = None
alignment_mas = None
# Calculate speaker embedding
# if self.speaker_emb is None:
@@ -508,6 +589,16 @@ class FastPitch(BaseTTS):
o_dr_log = self.duration_predictor(o_en_dr, mask_en_dr)
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
# Aligner
if self.use_aligner:
alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), embedding.transpose(1, 2), x_mask, None)
alignment_mas = maximum_path(
alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
)
o_alignment_dur = torch.log(1 + torch.sum(alignment_mas, -1))
avg_pitch = average_pitch(pitch, o_alignment_dur)
dr = o_alignment_dur
# TODO: move this to the dataset
avg_pitch = average_pitch(pitch, dr)
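To make the duration extraction concrete: alignment_mas is the hard 0/1 matrix returned by maximum_path, its row sums are per-token frame counts, and the branch above stores them in log1p form. A toy example:

import torch

# One item: 3 text tokens aligned over 6 decoder frames.
alignment_mas = torch.tensor([[1., 1., 0., 0., 0., 0.],
                              [0., 0., 1., 1., 1., 0.],
                              [0., 0., 0., 0., 0., 1.]])
durations = alignment_mas.sum(-1)  # tensor([2., 3., 1.])
o_alignment_dur = torch.log(1 + durations)  # log1p form, matching the line above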
@@ -529,6 +620,9 @@ class FastPitch(BaseTTS):
"pitch": o_pitch,
"pitch_gt": avg_pitch,
"alignments": attn,
"alignment_mas": alignment_mas,
"o_alignment_dur": o_alignment_dur,
"alignment_logprob": alignment_logprob,
}
return outputs
@@ -599,7 +693,12 @@ class FastPitch(BaseTTS):
durations = batch["durations"]
aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
outputs = self.forward(text_input, text_lengths, mel_lengths, durations, pitch, aux_input)
outputs = self.forward(
text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input
)
if self.use_aligner:
durations = outputs["o_alignment_dur"]
# compute loss
loss_dict = criterion(
@@ -611,6 +710,7 @@ class FastPitch(BaseTTS):
outputs["pitch"],
outputs["pitch_gt"],
text_lengths,
outputs["alignment_logprob"],
)
# compute duration error
@@ -634,6 +734,10 @@ class FastPitch(BaseTTS):
"alignment": plot_alignment(align_img, output_fig=False),
}
if self.config.model_args.use_aligner and self.training:
alignment_mas = outputs["alignment_mas"]
figures["alignment_mas"] = plot_alignment(alignment_mas, ap, output_fig=False)
# Sample audio
train_audio = ap.inv_melspectrogram(pred_spec.T)
return figures, {"audio": train_audio}

View File

@@ -11,7 +11,7 @@ output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(
name="ljspeech",
meta_file_train="metadata.csv",
meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
# meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
path=os.path.join(output_path, "../LJSpeech-1.1/"),
)
audio_config = BaseAudioConfig(
@@ -46,14 +46,17 @@ config = FastPitchConfig(
print_eval=False,
mixed_precision=False,
output_path=output_path,
datasets=[dataset_config]
datasets=[dataset_config],
)
# compute alignments
manager = ModelManager()
model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
# TODO: make compute_attention python callable
os.system(f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true")
if not config.model_args.use_aligner:
manager = ModelManager()
model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
# TODO: make compute_attention python callable
os.system(
f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true"
)
# train the model
args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)