mirror of https://github.com/coqui-ai/TTS.git

Enable aligner for FastPitch

parent 81c228a2d8
commit e429afbce4

@@ -76,6 +76,7 @@ class FastPitchConfig(BaseTTSConfig):
     spec_loss_alpha: float = 1.0
     pitch_loss_alpha: float = 1.0
     dur_loss_alpha: float = 1.0
+    aligner_loss_alpha: float = 1.0
 
     # overrides
     min_seq_len: int = 13

@@ -47,8 +47,9 @@ def maximum_path(value, mask):
 
 def maximum_path_cython(value, mask):
     """Cython optimised version.
-    value: [b, t_x, t_y]
-    mask: [b, t_x, t_y]
+    Shapes:
+        - value: :math:`[B, T_en, T_de]`
+        - mask: :math:`[B, T_en, T_de]`
     """
     value = value * mask
     device = value.device

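For orientation: maximum_path runs monotonic alignment search (MAS) over a pairwise score matrix and returns a hard 0/1 path assigning every decoder frame to exactly one encoder token. A minimal sketch of a call, assuming the import path used elsewhere in this commit (toy shapes, random scores):

import torch

from TTS.tts.layers.glow_tts.monotonic_align import maximum_path

value = torch.randn(1, 4, 6)  # pairwise scores, [B, T_en, T_de]
mask = torch.ones(1, 4, 6)    # all positions valid, [B, T_en, T_de]

path = maximum_path(value, mask)  # binary monotonic alignment, [B, T_en, T_de]
print(path.sum(-1))               # frames claimed by each of the 4 tokens; totals 6
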
@@ -660,6 +660,36 @@ class VitsDiscriminatorLoss(nn.Module):
         return return_dict
 
 
+class ForwardSumLoss(nn.Module):
+    def __init__(self, blank_logprob=-1):
+        super().__init__()
+        self.log_softmax = torch.nn.LogSoftmax(dim=3)
+        self.ctc_loss = torch.nn.CTCLoss(zero_infinity=True)
+        self.blank_logprob = blank_logprob
+
+    def forward(self, attn_logprob, in_lens, out_lens):
+        key_lens = in_lens
+        query_lens = out_lens
+        attn_logprob_padded = torch.nn.functional.pad(input=attn_logprob, pad=(1, 0), value=self.blank_logprob)
+
+        total_loss = 0.0
+        for bid in range(attn_logprob.shape[0]):
+            target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
+            curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)[: query_lens[bid], :, : key_lens[bid] + 1]
+
+            curr_logprob = self.log_softmax(curr_logprob[None])[0]
+            loss = self.ctc_loss(
+                curr_logprob,
+                target_seq,
+                input_lengths=query_lens[bid : bid + 1],
+                target_lengths=key_lens[bid : bid + 1],
+            )
+            total_loss = total_loss + loss
+
+        total_loss = total_loss / attn_logprob.shape[0]
+        return total_loss
+
+
 class FastPitchLoss(nn.Module):
     def __init__(self, c):
         super().__init__()

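ForwardSumLoss casts alignment learning as CTC: the text token indices 1..key_len form the target sequence, the aligner's attention scores are the emissions, and the blank column padded in at index 0 lets CTC skip frames, so minimising the loss maximises the summed probability of all valid monotonic alignments. A minimal usage sketch with toy shapes, assuming ForwardSumLoss from the hunk above is in scope (the [B, 1, T_de, T_en] layout is inferred from the permute/slice logic):

import torch

loss_fn = ForwardSumLoss()

B, T_en, T_de = 2, 5, 20
attn_logprob = torch.randn(B, 1, T_de, T_en)  # unnormalised attention scores from the aligner
in_lens = torch.tensor([5, 4])                # text (key) lengths
out_lens = torch.tensor([20, 18])             # spectrogram (query) lengths

loss = loss_fn(attn_logprob, in_lens, out_lens)  # scalar, averaged over the batch
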
@@ -667,11 +697,14 @@ class FastPitchLoss(nn.Module):
         self.ssim = SSIMLoss()
         self.dur_loss = MSELossMasked(False)
         self.pitch_loss = MSELossMasked(False)
+        if c.model_args.use_aligner:
+            self.aligner_loss = ForwardSumLoss()
 
         self.spec_loss_alpha = c.spec_loss_alpha
         self.ssim_loss_alpha = c.ssim_loss_alpha
         self.dur_loss_alpha = c.dur_loss_alpha
         self.pitch_loss_alpha = c.pitch_loss_alpha
+        self.aligner_loss_alpha = c.aligner_loss_alpha
 
     def forward(
         self,

@@ -683,29 +716,35 @@ class FastPitchLoss(nn.Module):
         pitch_output,
         pitch_target,
         input_lens,
+        alignment_logprob=None,
     ):
         loss = 0
         return_dict = {}
         if self.ssim_loss_alpha > 0:
             ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
-            loss += self.ssim_loss_alpha * ssim_loss
+            loss = loss + self.ssim_loss_alpha * ssim_loss
             return_dict["loss_ssim"] = self.ssim_loss_alpha * ssim_loss
 
         if self.spec_loss_alpha > 0:
             spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
-            loss += self.spec_loss_alpha * spec_loss
+            loss = loss + self.spec_loss_alpha * spec_loss
             return_dict["loss_spec"] = self.spec_loss_alpha * spec_loss
 
         if self.dur_loss_alpha > 0:
             log_dur_tgt = torch.log(dur_target.float() + 1)
             dur_loss = self.dur_loss(dur_output[:, :, None], log_dur_tgt[:, :, None], input_lens)
-            loss += self.dur_loss_alpha * dur_loss
+            loss = loss + self.dur_loss_alpha * dur_loss
             return_dict["loss_dur"] = self.dur_loss_alpha * dur_loss
 
         if self.pitch_loss_alpha > 0:
             pitch_loss = self.pitch_loss(pitch_output.transpose(1, 2), pitch_target.transpose(1, 2), input_lens)
-            loss += self.pitch_loss_alpha * pitch_loss
+            loss = loss + self.pitch_loss_alpha * pitch_loss
             return_dict["loss_pitch"] = self.pitch_loss_alpha * pitch_loss
 
+        if self.aligner_loss_alpha > 0:
+            aligner_loss = self.aligner_loss(alignment_logprob, input_lens, decoder_output_lens)
+            loss += self.aligner_loss_alpha * aligner_loss
+            return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
+
         return_dict["loss"] = loss
         return return_dict

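Each term is gated and scaled by its alpha, so any component can be switched off from the config by zeroing its weight. A hypothetical override (field names come from the FastPitchConfig hunk at the top and the attributes read above; the import path is an assumption and may differ by version):

from TTS.tts.configs import FastPitchConfig  # assumed import path

config = FastPitchConfig(
    spec_loss_alpha=1.0,
    ssim_loss_alpha=1.0,
    dur_loss_alpha=1.0,
    pitch_loss_alpha=1.0,
    aligner_loss_alpha=0.0,  # skips the aligner CTC term in FastPitchLoss
)
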
@@ -4,8 +4,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from coqpit import Coqpit
+from matplotlib.pyplot import plot
 
-from TTS.tts.layers.glow_tts.monotonic_align import generate_path
+from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram

@@ -14,6 +15,75 @@ from TTS.utils.audio import AudioProcessor
 # pylint: disable=dangerous-default-value
 
 
+class AlignmentEncoder(torch.nn.Module):
+    """Module for aligning text with a mel spectrogram."""
+
+    def __init__(
+        self,
+        in_query_channels=80,
+        in_key_channels=512,
+        attn_channels=80,
+        temperature=0.0005,
+    ):
+        super().__init__()
+        self.temperature = temperature
+        self.softmax = torch.nn.Softmax(dim=3)
+        self.log_softmax = torch.nn.LogSoftmax(dim=3)
+
+        self.key_proj = nn.Sequential(
+            ConvNorm(
+                in_key_channels, in_key_channels * 2, kernel_size=3, bias=True, w_init_gain="relu", batch_norm=False
+            ),
+            torch.nn.ReLU(),
+            ConvNorm(in_key_channels * 2, attn_channels, kernel_size=1, bias=True, batch_norm=False),
+        )
+
+        self.query_proj = nn.Sequential(
+            ConvNorm(
+                in_query_channels, in_query_channels * 2, kernel_size=3, bias=True, w_init_gain="relu", batch_norm=False
+            ),
+            torch.nn.ReLU(),
+            ConvNorm(in_query_channels * 2, in_query_channels, kernel_size=1, bias=True, batch_norm=False),
+            torch.nn.ReLU(),
+            ConvNorm(in_query_channels, attn_channels, kernel_size=1, bias=True, batch_norm=False),
+        )
+
+    def forward(
+        self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None
+    ):
+        """Forward pass of the aligner encoder.
+        Args:
+            queries (torch.tensor): query tensor.
+            keys (torch.tensor): key tensor.
+            mask (torch.tensor): uint8 binary mask for variable-length entries (over the T2 domain).
+            attn_prior (torch.tensor): prior for the attention matrix.
+        Shapes:
+            - queries: :math:`(B, C, T_de)`
+            - keys: :math:`(B, C_emb, T_en)`
+            - mask: :math:`(B, 1, T_en)`
+        Output:
+            attn (torch.tensor): B x 1 x T1 x T2 attention mask. Final dim T2 should sum to 1.
+            attn_logprob (torch.tensor): B x 1 x T1 x T2 log-prob attention mask.
+        """
+        keys_enc = self.key_proj(keys)  # B x n_attn_dims x T2
+        queries_enc = self.query_proj(queries)
+
+        # Simplistic Gaussian isotropic attention.
+        attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2  # B x n_attn_dims x T1 x T2
+        attn = -self.temperature * attn.sum(1, keepdim=True)
+
+        if attn_prior is not None:
+            attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + 1e-8)
+
+        attn_logprob = attn.clone()
+
+        if mask is not None:
+            attn.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))
+
+        attn = self.softmax(attn)  # softmax along T2
+        return attn, attn_logprob
+
+
 def mask_from_lens(lens, max_len: int = None):
     if max_len is None:
         max_len = lens.max()

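The encoder projects mel frames (queries) and character embeddings (keys) into a shared attn_channels space and scores every frame/token pair by negative squared distance, i.e. an isotropic Gaussian log-score sharpened by the temperature. A usage sketch with toy shapes (channel sizes follow the defaults above; ConvNorm must be available in the module, as the hunk assumes):

import torch

aligner = AlignmentEncoder(in_query_channels=80, in_key_channels=512, attn_channels=80)

B, T_de, T_en = 2, 50, 12
queries = torch.randn(B, 80, T_de)  # mel frames, (B, C, T_de)
keys = torch.randn(B, 512, T_en)    # text embeddings, (B, C_emb, T_en)
mask = torch.ones(B, 1, T_en)       # all token positions valid

attn, attn_logprob = aligner(queries, keys, mask=mask)
# attn: (B, 1, T_de, T_en); each frame's row sums to 1 over T_en
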
@@ -379,6 +449,7 @@ class FastPitchArgs(Coqpit):
     detach_duration_predictor: bool = False
     max_duration: int = 75
     use_gt_duration: bool = True
+    use_aligner: bool = True
 
 
 class FastPitch(BaseTTS):

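use_aligner defaults to True, so duration targets now come from the model's own aligner rather than from precomputed attention masks. A hypothetical override for users who still want the legacy pipeline (FastPitchArgs is a Coqpit dataclass, so keyword construction applies):

model_args = FastPitchArgs(use_aligner=False, use_gt_duration=True)  # fall back to externally computed durations
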
@@ -421,6 +492,7 @@ class FastPitch(BaseTTS):
 
         self.max_duration = args.max_duration
         self.use_gt_duration = args.use_gt_duration
+        self.use_aligner = args.use_aligner
 
         self.length_scale = float(args.length_scale) if isinstance(args.length_scale, int) else args.length_scale

@@ -463,6 +535,9 @@ class FastPitch(BaseTTS):
 
         self.proj = nn.Linear(args.hidden_channels, args.out_channels, bias=True)
 
+        if args.use_aligner:
+            self.aligner = AlignmentEncoder(args.out_channels, args.hidden_channels)
+
     @staticmethod
     def expand_encoder_outputs(en, dr, x_mask, y_mask):
         """Generate attention alignment map from durations and

@@ -483,10 +558,16 @@ class FastPitch(BaseTTS):
         o_en_ex = torch.matmul(attn.transpose(1, 2), en)
         return o_en_ex, attn.transpose(1, 2)
 
-    def forward(self, x, x_lengths, y_lengths, dr, pitch, aux_input={"d_vectors": 0, "speaker_ids": None}):
+    def forward(
+        self, x, x_lengths, y_lengths, y=None, dr=None, pitch=None, aux_input={"d_vectors": 0, "speaker_ids": None}
+    ):
         speaker_embedding = aux_input["d_vectors"] if "d_vectors" in aux_input else 0
         y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype)
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
+        o_alignment_dur = None
+        alignment_logprob = None
+        alignment_mas = None
+
         # Calculate speaker embedding
         # if self.speaker_emb is None:

@@ -508,6 +589,16 @@ class FastPitch(BaseTTS):
         o_dr_log = self.duration_predictor(o_en_dr, mask_en_dr)
         o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
 
+        # Aligner
+        if self.use_aligner:
+            alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), embedding.transpose(1, 2), x_mask, None)
+            alignment_mas = maximum_path(
+                alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
+            )
+            o_alignment_dur = torch.log(1 + torch.sum(alignment_mas, -1))
+            avg_pitch = average_pitch(pitch, o_alignment_dur)
+            dr = o_alignment_dur
+
         # TODO: move this to the dataset
         avg_pitch = average_pitch(pitch, dr)

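Here the soft attention is binarised with maximum_path (MAS), and each token's duration is the number of frames its row claims; the hunk then stores these counts in the log domain, log(1 + n_frames), as o_alignment_dur. A toy illustration of that reduction:

import torch

# Binary MAS alignment, [B, T_en, T_de]: 3 tokens over 5 frames.
alignment_mas = torch.tensor([[[1, 1, 0, 0, 0],
                               [0, 0, 1, 0, 0],
                               [0, 0, 0, 1, 1]]], dtype=torch.float)

durations = alignment_mas.sum(-1)           # tensor([[2., 1., 2.]]), frames per token
o_alignment_dur = torch.log(1 + durations)  # log-domain durations, as in the hunk
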
@@ -529,6 +620,9 @@ class FastPitch(BaseTTS):
             "pitch": o_pitch,
             "pitch_gt": avg_pitch,
             "alignments": attn,
+            "alignment_mas": alignment_mas,
+            "o_alignment_dur": o_alignment_dur,
+            "alignment_logprob": alignment_logprob,
         }
         return outputs

@@ -599,7 +693,12 @@ class FastPitch(BaseTTS):
         durations = batch["durations"]
 
         aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
-        outputs = self.forward(text_input, text_lengths, mel_lengths, durations, pitch, aux_input)
+        outputs = self.forward(
+            text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input
+        )
+
+        if self.use_aligner:
+            durations = outputs["o_alignment_dur"]
 
         # compute loss
         loss_dict = criterion(

@@ -611,6 +710,7 @@ class FastPitch(BaseTTS):
             outputs["pitch"],
             outputs["pitch_gt"],
             text_lengths,
+            outputs["alignment_logprob"],
         )
 
         # compute duration error

@@ -634,6 +734,10 @@ class FastPitch(BaseTTS):
             "alignment": plot_alignment(align_img, output_fig=False),
         }
 
+        if self.config.model_args.use_aligner and self.training:
+            alignment_mas = outputs["alignment_mas"]
+            figures["alignment_mas"] = plot_alignment(alignment_mas, ap, output_fig=False)
+
         # Sample audio
         train_audio = ap.inv_melspectrogram(pred_spec.T)
         return figures, {"audio": train_audio}

@@ -11,7 +11,7 @@ output_path = os.path.dirname(os.path.abspath(__file__))
 dataset_config = BaseDatasetConfig(
     name="ljspeech",
     meta_file_train="metadata.csv",
-    meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
+    # meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
     path=os.path.join(output_path, "../LJSpeech-1.1/"),
 )
 audio_config = BaseAudioConfig(

@@ -46,14 +46,17 @@ config = FastPitchConfig(
     print_eval=False,
     mixed_precision=False,
     output_path=output_path,
-    datasets=[dataset_config]
+    datasets=[dataset_config],
 )
 
 # compute alignments
-manager = ModelManager()
-model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
-# TODO: make compute_attention python callable
-os.system(f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true")
+if not config.model_args.use_aligner:
+    manager = ModelManager()
+    model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA")
+    # TODO: make compute_attention python callable
+    os.system(
+        f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true"
+    )
 
 # train the model
 args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config)

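Because use_aligner is on by default, the Tacotron2-DCA download and attention-mask extraction above now only run for the legacy setup. A hypothetical toggle, if set before the init_training call:

config.model_args.use_aligner = False  # would re-enable the compute_attention_masks.py branch above
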