FastPitch refactor and commenting

Author: Eren Gölge
Date:   2021-09-03 13:26:49 +00:00
parent 59b24e66cf
commit 2bf9e83c49

4 changed files with 403 additions and 204 deletions

View File

@@ -756,7 +756,7 @@ class FastPitchLoss(nn.Module):
             loss = loss + self.aligner_loss_alpha * aligner_loss
             return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
-        if self.binary_alignment_loss_alpha > 0:
+        if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
             binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
             loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
             return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss
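The new `alignment_hard is not None` guard keeps the binary alignment term out of the loss until the trainer enables it (see `on_train_step_start` further down). `_binary_alignment_loss` itself is outside this hunk; a minimal sketch of such a binary alignment objective (an assumption about its shape, not the exact repo code) sums the soft-attention log probabilities at the positions picked by the hard MAS alignment:

    import torch

    def binary_alignment_loss(alignment_hard, alignment_soft):
        # sum log p(soft) only where the hard (MAS) alignment is 1, pushing the
        # soft attention to collapse onto the monotonic path
        picked = alignment_soft[alignment_hard == 1]
        return -torch.log(torch.clamp(picked, min=1e-12)).sum() / alignment_hard.sum()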

View File

@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Tuple
+from typing import Dict, Tuple

 import torch
 from coqpit import Coqpit
@@ -8,6 +8,7 @@ from torch.cuda.amp.autocast_mode import autocast
 from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.encoder import Encoder
+from TTS.tts.layers.generic.aligner import AlignmentNetwork
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
@@ -15,87 +16,101 @@ from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
-from TTS.utils.soft_dtw import SoftDTW
-
-
-class AlignmentEncoder(torch.nn.Module):
-    def __init__(
-        self,
-        in_query_channels=80,
-        in_key_channels=512,
-        attn_channels=80,
-        temperature=0.0005,
-    ):
-        super().__init__()
-        self.temperature = temperature
-        self.softmax = torch.nn.Softmax(dim=3)
-        self.log_softmax = torch.nn.LogSoftmax(dim=3)
-
-        self.key_layer = nn.Sequential(
-            nn.Conv1d(
-                in_key_channels,
-                in_key_channels * 2,
-                kernel_size=3,
-                padding=1,
-                bias=True,
-            ),
-            torch.nn.ReLU(),
-            nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True),
-        )
-
-        self.query_layer = nn.Sequential(
-            nn.Conv1d(
-                in_query_channels,
-                in_query_channels * 2,
-                kernel_size=3,
-                padding=1,
-                bias=True,
-            ),
-            torch.nn.ReLU(),
-            nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True),
-            torch.nn.ReLU(),
-            nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True),
-        )
-
-    def forward(
-        self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None
-    ) -> Tuple[torch.tensor, torch.tensor]:
-        """Forward pass of the aligner encoder.
-        Shapes:
-            - queries: :math:`[B, C, T_de]`
-            - keys: :math:`[B, C_emb, T_en]`
-            - mask: :math:`[B, T_de]`
-        Output:
-            attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask.
-            attn_logp (torch.tensor): :math:`[B, 1, T_en, T_de]` log probabilities.
-        """
-        key_out = self.key_layer(keys)
-        query_out = self.query_layer(queries)
-        attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2
-        attn_factor = -self.temperature * attn_factor.sum(1, keepdim=True)
-        if attn_prior is not None:
-            attn_logp = self.log_softmax(attn_factor) + torch.log(attn_prior[:, None] + 1e-8)
-        if mask is not None:
-            attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))
-        attn = self.softmax(attn_logp)
-        return attn, attn_logp
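The class removed above lives on as `AlignmentNetwork` (imported at the top of this file). Its core idea is attention built from a pairwise L2 distance: the closer a decoder-frame query is to an encoder-token key, the more mass that token gets after a softmax over the token axis. A minimal sketch of that energy computation, assuming already-projected queries `q: [B, C, T_de]` and keys `k: [B, C, T_en]`:

    import torch

    def distance_attention(q, k, temperature=0.0005):
        # squared L2 distance of every (frame, token) pair -> [B, 1, T_de, T_en]
        dist = ((q[:, :, :, None] - k[:, :, None]) ** 2).sum(1, keepdim=True)
        # smaller distance -> larger score; normalize over the encoder-token axis
        attn_logp = torch.nn.functional.log_softmax(-temperature * dist, dim=3)
        return attn_logp.exp(), attn_logp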
 @dataclass
 class FastPitchArgs(Coqpit):
+    """Fast Pitch Model arguments.
+
+    Args:
+        num_chars (int):
+            Number of characters in the vocabulary. Defaults to 100.
+        out_channels (int):
+            Number of output channels. Defaults to 80.
+        hidden_channels (int):
+            Number of base hidden channels of the model. Defaults to 512.
+        num_speakers (int):
+            Number of speakers for the speaker embedding layer. Defaults to 0.
+        duration_predictor_hidden_channels (int):
+            Number of hidden channels in the duration predictor. Defaults to 256.
+        duration_predictor_dropout_p (float):
+            Dropout rate for the duration predictor. Defaults to 0.1.
+        duration_predictor_kernel_size (int):
+            Kernel size of conv layers in the duration predictor. Defaults to 3.
+        pitch_predictor_hidden_channels (int):
+            Number of hidden channels in the pitch predictor. Defaults to 256.
+        pitch_predictor_dropout_p (float):
+            Dropout rate for the pitch predictor. Defaults to 0.1.
+        pitch_predictor_kernel_size (int):
+            Kernel size of conv layers in the pitch predictor. Defaults to 3.
+        pitch_embedding_kernel_size (int):
+            Kernel size of the projection layer in the pitch predictor. Defaults to 3.
+        positional_encoding (bool):
+            Whether to use positional encoding. Defaults to True.
+        positional_encoding_use_scale (bool):
+            Whether to use a learnable scale coeff in the positional encoding. Defaults to True.
+        length_scale (int):
+            Length scale that multiplies the predicted durations. Larger values result in slower speech. Defaults to 1.0.
+        encoder_type (str):
+            Type of the encoder module. One of the encoders available in :class:`TTS.tts.layers.feed_forward.encoder`.
+            Defaults to `fftransformer` as in the paper.
+        encoder_params (dict):
+            Parameters of the encoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}```
+        decoder_type (str):
+            Type of the decoder module. One of the decoders available in :class:`TTS.tts.layers.feed_forward.decoder`.
+            Defaults to `fftransformer` as in the paper.
+        decoder_params (dict):
+            Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}```
+        use_d_vector (bool):
+            Whether to use precomputed d-vectors for multi-speaker training. Defaults to False.
+        d_vector_dim (int):
+            Number of channels of the d-vectors. Defaults to 0.
+        detach_duration_predictor (bool):
+            Detach the input to the duration predictor from the earlier computation graph so that the duration loss
+            does not pass to the earlier layers. Defaults to True.
+        max_duration (int):
+            Maximum duration accepted by the model. Defaults to 75.
+        use_aligner (bool):
+            Use aligner network to learn the text to speech alignment. Defaults to True.
+    """
     num_chars: int = None
     out_channels: int = 80
-    hidden_channels: int = 256
+    hidden_channels: int = 384
     num_speakers: int = 0
     duration_predictor_hidden_channels: int = 256
-    duration_predictor_dropout: float = 0.1
     duration_predictor_kernel_size: int = 3
     duration_predictor_dropout_p: float = 0.1
     pitch_predictor_hidden_channels: int = 256
-    pitch_predictor_dropout: float = 0.1
     pitch_predictor_kernel_size: int = 3
     pitch_predictor_dropout_p: float = 0.1
     pitch_embedding_kernel_size: int = 3
     positional_encoding: bool = True
+    poisitonal_encoding_use_scale: bool = True
     length_scale: int = 1
     encoder_type: str = "fftransformer"
     encoder_params: dict = field(
@@ -109,14 +124,16 @@ class FastPitchArgs(Coqpit):
     d_vector_dim: int = 0
     detach_duration_predictor: bool = False
     max_duration: int = 75
-    use_gt_duration: bool = True
     use_aligner: bool = True
 class FastPitch(BaseTTS):
     """FastPitch model. Very similar to SpeedySpeech model but with pitch prediction.

-    Paper abstract:
+    Paper::
+        https://arxiv.org/abs/2006.06873
+
+    Paper abstract::
         We present FastPitch, a fully-parallel text-to-speech model based on FastSpeech, conditioned on fundamental
         frequency contours. The model predicts pitch contours during inference. By altering these predictions,
         the generated speech can be more expressive, better match the semantic of the utterance, and in the end
@@ -126,9 +143,6 @@ class FastPitch(BaseTTS):
         and FastPitch retains the favorable, fully-parallel Transformer architecture, with over 900x real-time
         factor for mel-spectrogram synthesis of a typical utterance."

-    Notes:
-        TODO
-
     Args:
         config (Coqpit): Model coqpit class.
@@ -143,95 +157,138 @@ class FastPitch(BaseTTS):
         super().__init__()

-        if "characters" in config:
-            # loading from FastPitchConfig
-            _, self.config, num_chars = self.get_characters(config)
-            config.model_args.num_chars = num_chars
-            args = self.config.model_args
-        else:
-            # loading from FastPitchArgs
-            self.config = config
-            args = config
+        # don't check with isinstance to avoid a recursive import
+        if config.__class__.__name__ == "FastPitchConfig":
+            if "characters" in config:
+                # loading from FastPitchConfig
+                _, self.config, num_chars = self.get_characters(config)
+                config.model_args.num_chars = num_chars
+                self.args = self.config.model_args
+            else:
+                # loading from FastPitchArgs
+                self.config = config
+                self.args = config.model_args
+        elif isinstance(config, FastPitchArgs):
+            self.args = config
+            self.config = config
+        else:
+            raise ValueError("config must be either a FastPitchConfig or FastPitchArgs")

-        self.max_duration = args.max_duration
-        self.use_gt_duration = args.use_gt_duration
-        self.use_aligner = args.use_aligner
-        self.length_scale = float(args.length_scale) if isinstance(args.length_scale, int) else args.length_scale
+        self.max_duration = self.args.max_duration
+        self.use_aligner = self.args.use_aligner
+        self.use_binary_alignment_loss = False
+        self.length_scale = (
+            float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale
+        )

-        self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels)
+        self.emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)

         self.encoder = Encoder(
-            config.model_args.hidden_channels,
-            config.model_args.hidden_channels,
-            config.model_args.encoder_type,
-            config.model_args.encoder_params,
-            config.model_args.d_vector_dim,
+            self.args.hidden_channels,
+            self.args.hidden_channels,
+            self.args.encoder_type,
+            self.args.encoder_params,
+            self.args.d_vector_dim,
         )

-        if config.model_args.positional_encoding:
-            self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels)
+        if self.args.positional_encoding:
+            self.pos_encoder = PositionalEncoding(self.args.hidden_channels)

         self.decoder = Decoder(
-            config.model_args.out_channels,
-            config.model_args.hidden_channels,
-            config.model_args.decoder_type,
-            config.model_args.decoder_params,
+            self.args.out_channels,
+            self.args.hidden_channels,
+            self.args.decoder_type,
+            self.args.decoder_params,
         )

         self.duration_predictor = DurationPredictor(
-            config.model_args.hidden_channels + config.model_args.d_vector_dim,
-            config.model_args.duration_predictor_hidden_channels,
-            config.model_args.duration_predictor_kernel_size,
-            config.model_args.duration_predictor_dropout_p,
+            self.args.hidden_channels + self.args.d_vector_dim,
+            self.args.duration_predictor_hidden_channels,
+            self.args.duration_predictor_kernel_size,
+            self.args.duration_predictor_dropout_p,
         )

         self.pitch_predictor = DurationPredictor(
-            config.model_args.hidden_channels + config.model_args.d_vector_dim,
-            config.model_args.pitch_predictor_hidden_channels,
-            config.model_args.pitch_predictor_kernel_size,
-            config.model_args.pitch_predictor_dropout_p,
+            self.args.hidden_channels + self.args.d_vector_dim,
+            self.args.pitch_predictor_hidden_channels,
+            self.args.pitch_predictor_kernel_size,
+            self.args.pitch_predictor_dropout_p,
         )

         self.pitch_emb = nn.Conv1d(
             1,
-            config.model_args.hidden_channels,
-            kernel_size=config.model_args.pitch_embedding_kernel_size,
-            padding=int((config.model_args.pitch_embedding_kernel_size - 1) / 2),
+            self.args.hidden_channels,
+            kernel_size=self.args.pitch_embedding_kernel_size,
+            padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
         )

-        if config.model_args.num_speakers > 1 and not config.model_args.use_d_vector:
+        if self.args.num_speakers > 1 and not self.args.use_d_vector:
             # speaker embedding layer
-            self.emb_g = nn.Embedding(config.model_args.num_speakers, config.model_args.d_vector_dim)
+            self.emb_g = nn.Embedding(self.args.num_speakers, self.args.d_vector_dim)
             nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

-        if config.model_args.d_vector_dim > 0 and config.model_args.d_vector_dim != config.model_args.hidden_channels:
-            self.proj_g = nn.Conv1d(config.model_args.d_vector_dim, config.model_args.hidden_channels, 1)
+        if self.args.d_vector_dim > 0 and self.args.d_vector_dim != self.args.hidden_channels:
+            self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)

-        if args.use_aligner:
-            self.aligner = AlignmentEncoder(args.out_channels, args.hidden_channels)
+        if self.args.use_aligner:
+            self.aligner = AlignmentNetwork(in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels)
     @staticmethod
-    def expand_encoder_outputs(en, dr, x_mask, y_mask):
-        """Generate attention alignment map from durations and
-        expand encoder outputs
-
-        Example:
-            encoder output: [a,b,c,d]
-            durations: [1, 3, 2, 1]
-
-            expanded: [a, b, b, b, c, c, d]
-            attention map: [[0, 0, 0, 0, 0, 0, 1],
-                            [0, 0, 0, 0, 1, 1, 0],
-                            [0, 1, 1, 1, 0, 0, 0],
-                            [1, 0, 0, 0, 0, 0, 0]]
-        """
-        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
-        attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
-        o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
+    def generate_attn(dr, x_mask, y_mask=None):
+        """Generate an attention mask from the durations.
+
+        Shapes:
+            - dr: :math:`(B, T_{en})`
+            - x_mask: :math:`(B, T_{en})`
+            - y_mask: :math:`(B, T_{de})`
+        """
+        # compute decode mask from the durations
+        if y_mask is None:
+            y_lengths = dr.sum(1).long()
+            y_lengths[y_lengths < 1] = 1
+            y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
+        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
+        attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
+        return attn
+
+    def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
+        """Generate attention alignment map from durations and
+        expand encoder outputs
+
+        Shapes:
+            - en: :math:`(B, D_{en}, T_{en})`
+            - dr: :math:`(B, T_{en})`
+            - x_mask: :math:`(B, T_{en})`
+            - y_mask: :math:`(B, T_{de})`
+
+        Examples:
+            - encoder output: :math:`[a, b, c, d]`
+            - durations: :math:`[1, 3, 2, 1]`
+            - expanded: :math:`[a, b, b, b, c, c, d]`
+            - attention map: :math:`[[0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0]]`
+        """
+        attn = self.generate_attn(dr, x_mask, y_mask)
+        o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2)
         return o_en_ex, attn
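The duration-to-frame expansion in the docstring example can be checked with plain tensor ops; a small standalone sketch (using `repeat_interleave` instead of the repo's `generate_path`):

    import torch

    en = torch.tensor([[1.0, 2.0, 3.0, 4.0]]).unsqueeze(1)  # [B=1, C=1, T_en=4] -> "a, b, c, d"
    dr = torch.tensor([[1, 3, 2, 1]])                       # durations per input token
    idx = torch.repeat_interleave(torch.arange(4), dr[0])   # [0, 1, 1, 1, 2, 2, 3]
    print(en[0][:, idx])                                    # [[1., 2., 2., 2., 3., 3., 4.]] -> a b b b c c d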
     def format_durations(self, o_dr_log, x_mask):
+        """Format predicted durations.
+
+        1. Convert to linear scale from log scale.
+        2. Apply the length scale for speed adjustment.
+        3. Apply masking.
+        4. Cast 0 durations to 1.
+        5. Round the duration values.
+
+        Args:
+            o_dr_log: Log scale durations.
+            x_mask: Input text mask.
+
+        Shapes:
+            - o_dr_log: :math:`(B, T_{de})`
+            - x_mask: :math:`(B, T_{en})`
+        """
         o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale
         o_dr[o_dr < 1] = 1.0
         o_dr = torch.round(o_dr)
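The steps in the new docstring are easy to sanity-check in isolation; a sketch with made-up numbers and `length_scale = 1.0`:

    import torch

    o_dr_log = torch.tensor([[0.0, 1.2, 2.3]])                 # log-scale predictions
    x_mask, length_scale = torch.ones(1, 3), 1.0
    o_dr = (torch.exp(o_dr_log) - 1) * x_mask * length_scale   # [[0.00, 2.32, 8.97]]
    o_dr[o_dr < 1] = 1.0                                       # zero-length tokens still get one frame
    print(torch.round(o_dr))                                   # [[1., 2., 9.]]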
@@ -249,22 +306,39 @@ class FastPitch(BaseTTS):
             g = self.proj_g(g)
         return x + g
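The two visible context lines belong to the speaker-conditioning helpers: the duration predictor receives the speaker vector concatenated along channels, while the decoder input receives it as a sum, with `proj_g` bridging a dimension mismatch. A rough sketch of both patterns (helper names assumed from the calls elsewhere in this diff):

    import torch

    def concat_speaker_embedding(o_en, g):
        # g: [B, C, 1] broadcast over time, then channel-concatenated -> [B, C_en + C, T_en]
        return torch.cat([o_en, g.expand(-1, -1, o_en.size(-1))], dim=1)

    def sum_speaker_embedding(x, g, proj_g=None):
        # project g to the model width first when d_vector_dim != hidden_channels
        if proj_g is not None:
            g = proj_g(g)
        return x + g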
-    def _forward_encoder(self, x, x_lengths, g=None):
+    def _forward_encoder(
+        self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        """Encoding forward pass.
+
+        1. Embed speaker IDs if multi-speaker mode.
+        2. Embed character sequences.
+        3. Run the encoder network.
+        4. Concat speaker embedding to the encoder output for the duration predictor.
+
+        Args:
+            x (torch.LongTensor): Input sequence IDs.
+            x_mask (torch.FloatTensor): Input sequence mask.
+            g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None.
+
+        Returns:
+            Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor, torch.tensor]:
+                encoder output, encoder output for the duration predictor, input sequence mask, speaker embeddings,
+                character embeddings.
+
+        Shapes:
+            - x: :math:`(B, T_{en})`
+            - x_mask: :math:`(B, 1, T_{en})`
+            - g: :math:`(B, C)`
+        """
         if hasattr(self, "emb_g"):
             g = nn.functional.normalize(self.emb_g(g))  # [B, C, 1]
         if g is not None:
             g = g.unsqueeze(-1)
         # [B, T, C]
         x_emb = self.emb(x)
-        # compute sequence masks
-        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
         # encoder pass
         o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
         # speaker conditioning for duration predictor
         if g is not None:
             o_en_dp = self._concat_speaker_embedding(o_en, g)
@@ -272,8 +346,33 @@ class FastPitch(BaseTTS):
         else:
             o_en_dp = o_en
         return o_en, o_en_dp, x_mask, g, x_emb
-    def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
+    def _forward_decoder(
+        self,
+        o_en: torch.FloatTensor,
+        dr: torch.IntTensor,
+        x_mask: torch.FloatTensor,
+        y_lengths: torch.IntTensor,
+        g: torch.FloatTensor,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        """Decoding forward pass.
+
+        1. Compute the decoder output mask.
+        2. Expand encoder output with the durations.
+        3. Apply position encoding.
+        4. Add speaker embeddings if multi-speaker mode.
+        5. Run the decoder.
+
+        Args:
+            o_en (torch.FloatTensor): Encoder output.
+            dr (torch.IntTensor): Ground truth durations or alignment network durations.
+            x_mask (torch.IntTensor): Input sequence mask.
+            y_lengths (torch.IntTensor): Output sequence lengths.
+            g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings.
+
+        Returns:
+            Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations.
+        """
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
         # expand o_en with durations
         o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
         # positional encoding
@@ -286,7 +385,34 @@ class FastPitch(BaseTTS):
             o_de = self.decoder(o_en_ex, y_mask, g=g)
         return o_de.transpose(1, 2), attn.transpose(1, 2)
-    def _forward_pitch_predictor(self, o_en, x_mask, pitch=None, dr=None):
+    def _forward_pitch_predictor(
+        self,
+        o_en: torch.FloatTensor,
+        x_mask: torch.IntTensor,
+        pitch: torch.FloatTensor = None,
+        dr: torch.IntTensor = None,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        """Pitch predictor forward pass.
+
+        1. Predict pitch from encoder outputs.
+        2. In training, compute average pitch values for each input character from the ground truth pitch values.
+        3. Embed average pitch values.
+
+        Args:
+            o_en (torch.FloatTensor): Encoder output.
+            x_mask (torch.IntTensor): Input sequence mask.
+            pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
+            dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
+
+        Returns:
+            Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction.
+
+        Shapes:
+            - o_en: :math:`(B, C, T_{en})`
+            - x_mask: :math:`(B, 1, T_{en})`
+            - pitch: :math:`(B, 1, T_{de})`
+            - dr: :math:`(B, T_{en})`
+        """
         o_pitch = self.pitch_predictor(o_en, x_mask)
         if pitch is not None:
             avg_pitch = average_pitch(pitch, dr)
@@ -295,49 +421,111 @@ class FastPitch(BaseTTS):
             o_pitch_emb = self.pitch_emb(o_pitch)
         return o_pitch_emb, o_pitch
-    def _forward_aligner(self, y, embedding, x_mask, y_mask):
+    def _forward_aligner(
+        self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor
+    ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        """Aligner forward pass.
+
+        1. Compute a mask to apply to the attention map.
+        2. Run the alignment network.
+        3. Apply MAS to compute the hard alignment map.
+        4. Compute the durations from the hard alignment map.
+
+        Args:
+            x (torch.FloatTensor): Input sequence.
+            y (torch.FloatTensor): Output sequence.
+            x_mask (torch.IntTensor): Input sequence mask.
+            y_mask (torch.IntTensor): Output sequence mask.
+
+        Returns:
+            Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+                Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials,
+                hard alignment map.
+
+        Shapes:
+            - x: :math:`[B, T_en, C_en]`
+            - y: :math:`[B, T_de, C_de]`
+            - x_mask: :math:`[B, 1, T_en]`
+            - y_mask: :math:`[B, 1, T_de]`
+
+            - o_alignment_dur: :math:`[B, T_en]`
+            - alignment_soft: :math:`[B, T_en, T_de]`
+            - alignment_logprob: :math:`[B, 1, T_de, T_en]`
+            - alignment_mas: :math:`[B, T_en, T_de]`
+        """
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
-        alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), embedding.transpose(1, 2), x_mask, None)
+        alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None)
         alignment_mas = maximum_path(
             alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
         )
-        o_alignment_dur = torch.sum(alignment_mas, -1)
-        return o_alignment_dur, alignment_logprob, alignment_mas
+        o_alignment_dur = torch.sum(alignment_mas, -1).int()
+        alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
+        return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas
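`maximum_path` (imported from `TTS.tts.layers.glow_tts.monotonic_align`) performs monotonic alignment search: a Viterbi-style dynamic program that extracts the best strictly monotonic path through the soft alignment. A slow, unvectorized sketch of the idea for a single `[T_en, T_de]` log-probability matrix (the real implementation is vectorized/Cython):

    import torch

    def mas_sketch(log_p):
        """log_p: [T_en, T_de]; returns a 0/1 path assigning one token per frame."""
        T_en, T_de = log_p.shape
        Q = torch.full((T_en, T_de), float("-inf"))
        Q[0, 0] = log_p[0, 0]
        for j in range(1, T_de):
            for i in range(min(j + 1, T_en)):
                stay = Q[i, j - 1]
                move = Q[i - 1, j - 1] if i > 0 else float("-inf")
                Q[i, j] = log_p[i, j] + max(stay, move)  # stay on token i or advance to it
        # backtrack from the last (token, frame) pair
        path = torch.zeros(T_en, T_de)
        i = T_en - 1
        for j in range(T_de - 1, -1, -1):
            path[i, j] = 1.0
            if j > 0 and i > 0 and (i == j or Q[i - 1, j - 1] > Q[i, j - 1]):
                i -= 1
        return path  # durations are then path.sum(-1), as in the method above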
     def forward(
-        self, x, x_lengths, y_lengths, y=None, dr=None, pitch=None, aux_input={"d_vectors": 0, "speaker_ids": None}
-    ):  # pylint: disable=unused-argument
-        """
+        self,
+        x: torch.LongTensor,
+        x_lengths: torch.LongTensor,
+        y_lengths: torch.LongTensor,
+        y: torch.FloatTensor = None,
+        dr: torch.IntTensor = None,
+        pitch: torch.FloatTensor = None,
+        aux_input: Dict = {"d_vectors": 0, "speaker_ids": None},  # pylint: disable=unused-argument
+    ) -> Dict:
+        """Model's forward pass.
+
+        Args:
+            x (torch.LongTensor): Input character sequences.
+            x_lengths (torch.LongTensor): Input sequence lengths.
+            y_lengths (torch.LongTensor): Output sequence lengths. Defaults to None.
+            y (torch.FloatTensor): Spectrogram frames. Defaults to None.
+            dr (torch.IntTensor): Character durations over the spectrogram frames. Defaults to None.
+            pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Defaults to None.
+            aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": 0, "speaker_ids": None}`.
+
         Shapes:
-            x: :math:`[B, T_max]`
-            x_lengths: :math:`[B]`
-            y_lengths: :math:`[B]`
-            y: :math:`[B, T_max2]`
-            dr: :math:`[B, T_max]`
-            g: :math:`[B, C]`
-            pitch: :math:`[B, 1, T]`
+            - x: :math:`[B, T_max]`
+            - x_lengths: :math:`[B]`
+            - y_lengths: :math:`[B]`
+            - y: :math:`[B, T_max2]`
+            - dr: :math:`[B, T_max]`
+            - g: :math:`[B, C]`
+            - pitch: :math:`[B, 1, T]`
         """
         g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype)
-        o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_lengths, g)
-        if self.config.model_args.detach_duration_predictor:
+        # compute sequence masks
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(y.dtype)
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(y.dtype)
+        # encoder pass
+        o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
+        # duration predictor pass
+        if self.args.detach_duration_predictor:
             o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
         else:
             o_dr_log = self.duration_predictor(o_en_dp, x_mask)
         o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
+        # generate attn mask from predicted durations
+        o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
+        # aligner pass
         if self.use_aligner:
-            o_alignment_dur, alignment_logprob, alignment_mas = self._forward_aligner(y, x_emb, x_mask, y_mask)
+            o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner(
+                x_emb, y, x_mask, y_mask
+            )
             dr = o_alignment_dur
+        # pitch predictor pass
         o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr)
         o_en = o_en + o_pitch_emb
-        o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
+        # decoder pass
+        o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g)
         outputs = {
             "model_outputs": o_de,
             "durations_log": o_dr_log.squeeze(1),
             "durations": o_dr.squeeze(1),
+            "attn_durations": o_attn,  # for visualization
             "pitch": o_pitch,
             "pitch_gt": avg_pitch,
             "alignments": attn,
+            "alignment_soft": alignment_soft.transpose(1, 2),
             "alignment_mas": alignment_mas.transpose(1, 2),
             "o_alignment_dur": o_alignment_dur,
             "alignment_logprob": alignment_logprob,
@@ -346,43 +534,33 @@ class FastPitch(BaseTTS):

     @torch.no_grad()
     def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):  # pylint: disable=unused-argument
-        """
+        """Model's inference pass.
+
+        Args:
+            x (torch.LongTensor): Input character sequence.
+            aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`.
+
         Shapes:
-            x: [B, T_max]
-            x_lengths: [B]
-            g: [B, C]
+            - x: [B, T_max]
+            - x_lengths: [B]
+            - g: [B, C]
         """
         g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
         x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
-        # input sequence should be greater than the max convolution size
-        inference_padding = 5
-        if x.shape[1] < 13:
-            inference_padding += 13 - x.shape[1]
-        # pad input to prevent dropping the last word
-        x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0)
-        o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float()
+        # encoder pass
+        o_en, o_en_dp, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
         # duration predictor pass
         o_dr_log = self.duration_predictor(o_en_dp, x_mask)
         o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
+        y_lengths = o_dr.sum(1)
         # pitch predictor pass
         o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask)
-        # if pitch_transform is not None:
-        #     if self.pitch_std[0] == 0.0:
-        #         # XXX LJSpeech-1.1 defaults
-        #         mean, std = 218.14, 67.24
-        #     else:
-        #         mean, std = self.pitch_mean[0], self.pitch_std[0]
-        #     pitch_pred = pitch_transform(pitch_pred, enc_mask.sum(dim=(1,2)), mean, std)
-        # if pitch_tgt is None:
-        #     pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
-        # else:
-        #     pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)
         o_en = o_en + o_pitch_emb
-        y_lengths = o_dr.sum(1)
-        o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
+        # decoder pass
+        o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=g)
         outputs = {
-            "model_outputs": o_de.transpose(1, 2),
+            "model_outputs": o_de,
             "alignments": attn,
             "pitch": o_pitch,
             "durations_log": o_dr_log,
@@ -398,33 +576,35 @@ class FastPitch(BaseTTS):
         d_vectors = batch["d_vectors"]
         speaker_ids = batch["speaker_ids"]
         durations = batch["durations"]
         aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}

+        # forward pass
         outputs = self.forward(
             text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input
         )
+        # use aligner's output as the duration target
         if self.use_aligner:
             durations = outputs["o_alignment_dur"]
-        with autocast(enabled=False):  # use float32 for the criterion
+        # use float32 in AMP
+        with autocast(enabled=False):
             # compute loss
             loss_dict = criterion(
-                outputs["model_outputs"],
-                mel_input,
-                mel_lengths,
-                outputs["durations_log"],
-                durations,
-                outputs["pitch"],
-                outputs["pitch_gt"],
-                text_lengths,
-                outputs["alignment_logprob"],
+                decoder_output=outputs["model_outputs"],
+                decoder_target=mel_input,
+                decoder_output_lens=mel_lengths,
+                dur_output=outputs["durations_log"],
+                dur_target=durations,
+                pitch_output=outputs["pitch"],
+                pitch_target=outputs["pitch_gt"],
+                input_lens=text_lengths,
+                alignment_logprob=outputs["alignment_logprob"],
+                alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
+                alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
             )

         # compute duration error
         durations_pred = outputs["durations"]
         duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum()
         loss_dict["duration_error"] = duration_error

         return outputs, loss_dict
     def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):  # pylint: disable=no-self-use
@@ -442,9 +622,10 @@ class FastPitch(BaseTTS):
             "alignment": plot_alignment(align_img, output_fig=False),
         }

-        if self.config.model_args.use_aligner and self.training:
-            alignment_mas = outputs["alignment_mas"][0].data.cpu().numpy()
-            figures["alignment_mas"] = plot_alignment(alignment_mas, output_fig=False)
+        # plot the attention mask computed from the predicted durations
+        if "attn_durations" in outputs:
+            alignments_hat = outputs["attn_durations"][0].data.cpu().numpy()
+            figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False)

         # Sample audio
         train_audio = ap.inv_melspectrogram(pred_spec.T)
@@ -470,8 +651,20 @@ class FastPitch(BaseTTS):
         return FastPitchLoss(self.config)

+    def on_train_step_start(self, trainer):
+        """Enable binary alignment loss when needed"""
+        if trainer.total_steps_done > self.config.binary_align_loss_start_step:
+            self.use_binary_alignment_loss = True
+

 def average_pitch(pitch, durs):
+    """Compute the average pitch value for each input character based on the durations.
+
+    Shapes:
+        - pitch: :math:`[B, 1, T_de]`
+        - durs: :math:`[B, T_en]`
+    """
     durs_cums_ends = torch.cumsum(durs, dim=1).long()
     durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0))
     pitch_nonzero_cums = torch.nn.functional.pad(torch.cumsum(pitch != 0.0, dim=2), (1, 0))
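`average_pitch` uses cumulative sums so each character's mean can be gathered in O(1) from prefix sums; the `pitch != 0.0` counter excludes unvoiced frames from the average. A worked sketch of the same idea, computed naively for one utterance:

    import torch

    pitch = torch.tensor([110.0, 0.0, 115.0, 220.0, 230.0, 0.0, 95.0])  # per frame, 0 = unvoiced
    durs = torch.tensor([3, 2, 2])                                      # frames per character
    start = 0
    for d in durs.tolist():
        seg = pitch[start : start + d]
        voiced = seg[seg != 0.0]
        print(voiced.mean() if len(voiced) else 0.0)  # 112.5, 225.0, 95.0
        start += d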

View File

@@ -45,6 +45,7 @@
     models/glow_tts.md
     models/vits.md
+    models/fast_pitch.md

 .. toctree::
     :maxdepth: 2

View File

@@ -14,10 +14,11 @@ dataset_config = BaseDatasetConfig(
     # meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
     path=os.path.join(output_path, "../LJSpeech-1.1/"),
 )

 audio_config = BaseAudioConfig(
     sample_rate=22050,
-    do_trim_silence=False,
-    trim_db=0.0,
+    do_trim_silence=True,
+    trim_db=60.0,
     signal_norm=False,
     mel_fmin=0.0,
     mel_fmax=8000,
@@ -26,6 +27,7 @@ audio_config = BaseAudioConfig(
     ref_level_db=20,
     preemphasis=0.0,
 )

 config = FastPitchConfig(
     run_name="fast_pitch_ljspeech",
     audio=audio_config,
@@ -33,6 +35,7 @@ config = FastPitchConfig(
     eval_batch_size=16,
     num_loader_workers=8,
     num_eval_loader_workers=4,
+    compute_input_seq_cache=True,
     compute_f0=True,
     f0_cache_path=os.path.join(output_path, "f0_cache"),
     run_eval=True,
@@ -45,6 +48,8 @@ config = FastPitchConfig(
     print_step=50,
     print_eval=False,
     mixed_precision=False,
+    sort_by_audio_len=True,
+    max_seq_len=500000,
     output_path=output_path,
     datasets=[dataset_config],
 )
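The two new recipe options sort training batches by audio length and cap sample length. Assuming `max_seq_len` is measured in audio samples at the configured `sample_rate`, the cap works out to roughly 22.7 seconds per clip:

    max_seq_len = 500_000              # from the recipe above
    sample_rate = 22_050               # Hz, from audio_config
    print(max_seq_len / sample_rate)   # ~22.68 seconds of audio per training sample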