mirror of https://github.com/coqui-ai/TTS.git
Implement `forward_tts`
- Generic API for feed-forward TTS models (FastPitch, SpeedySpeech) - Tests for `forward-tts` - Edit FastPitchConfig and SpeedySpeechConfig to use `forward_tts`
This commit is contained in:
parent
3c740d4893
commit
8b7e094bde
|
@ -2,12 +2,12 @@ from dataclasses import dataclass, field
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
||||||
from TTS.tts.models.fast_pitch import FastPitchArgs
|
from TTS.tts.models.forward_tts import ForwardTTSArgs
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FastPitchConfig(BaseTTSConfig):
|
class FastPitchConfig(BaseTTSConfig):
|
||||||
"""Defines parameters for Speedy Speech (feed-forward encoder-decoder) based models.
|
"""Configure `ForwardTTS` as FastPitch model.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
|
@ -36,22 +36,43 @@ class FastPitchConfig(BaseTTSConfig):
|
||||||
d_vector_file (str):
|
d_vector_file (str):
|
||||||
Path to the file including pre-computed speaker embeddings. Defaults to None.
|
Path to the file including pre-computed speaker embeddings. Defaults to None.
|
||||||
|
|
||||||
noam_schedule (bool):
|
d_vector_dim (int):
|
||||||
enable / disable the use of Noam LR scheduler. Defaults to False.
|
Dimension of the external speaker embeddings. Defaults to 0.
|
||||||
|
|
||||||
warmup_steps (int):
|
optimizer (str):
|
||||||
Number of warm-up steps for the Noam scheduler. Defaults 4000.
|
Name of the model optimizer. Defaults to `Adam`.
|
||||||
|
|
||||||
|
optimizer_params (dict):
|
||||||
|
Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`.
|
||||||
|
|
||||||
|
lr_scheduler (str):
|
||||||
|
Name of the learning rate scheduler. Defaults to `Noam`.
|
||||||
|
|
||||||
|
lr_scheduler_params (dict):
|
||||||
|
Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`.
|
||||||
|
|
||||||
lr (float):
|
lr (float):
|
||||||
Initial learning rate. Defaults to `1e-3`.
|
Initial learning rate. Defaults to `1e-3`.
|
||||||
|
|
||||||
|
grad_clip (float):
|
||||||
|
Gradient norm clipping value. Defaults to `5.0`.
|
||||||
|
|
||||||
|
spec_loss_type (str):
|
||||||
|
Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
|
||||||
|
|
||||||
|
duration_loss_type (str):
|
||||||
|
Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `mse`.
|
||||||
|
|
||||||
|
use_ssim_loss (bool):
|
||||||
|
Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True.
|
||||||
|
|
||||||
wd (float):
|
wd (float):
|
||||||
Weight decay coefficient. Defaults to `1e-7`.
|
Weight decay coefficient. Defaults to `1e-7`.
|
||||||
|
|
||||||
ssim_loss_alpha (float):
|
ssim_loss_alpha (float):
|
||||||
Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0.
|
Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0.
|
||||||
|
|
||||||
huber_loss_alpha (float):
|
dur_loss_alpha (float):
|
||||||
Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0.
|
Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0.
|
||||||
|
|
||||||
spec_loss_alpha (float):
|
spec_loss_alpha (float):
|
||||||
|
@ -73,9 +94,9 @@ class FastPitchConfig(BaseTTSConfig):
|
||||||
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
|
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model: str = "fast_pitch"
|
model: str = "forward_tts"
|
||||||
# model specific params
|
# model specific params
|
||||||
model_args: FastPitchArgs = field(default_factory=FastPitchArgs)
|
model_args: ForwardTTSArgs = field(default_factory=ForwardTTSArgs)
|
||||||
|
|
||||||
# multi-speaker settings
|
# multi-speaker settings
|
||||||
use_speaker_embedding: bool = False
|
use_speaker_embedding: bool = False
|
||||||
|
@ -92,11 +113,13 @@ class FastPitchConfig(BaseTTSConfig):
|
||||||
grad_clip: float = 5.0
|
grad_clip: float = 5.0
|
||||||
|
|
||||||
# loss params
|
# loss params
|
||||||
|
spec_loss_type: str = "mse"
|
||||||
|
duration_loss_type: str = "mse"
|
||||||
|
use_ssim_loss: bool = True
|
||||||
ssim_loss_alpha: float = 1.0
|
ssim_loss_alpha: float = 1.0
|
||||||
dur_loss_alpha: float = 1.0
|
dur_loss_alpha: float = 1.0
|
||||||
spec_loss_alpha: float = 1.0
|
spec_loss_alpha: float = 1.0
|
||||||
pitch_loss_alpha: float = 1.0
|
pitch_loss_alpha: float = 1.0
|
||||||
dur_loss_alpha: float = 1.0
|
|
||||||
aligner_loss_alpha: float = 1.0
|
aligner_loss_alpha: float = 1.0
|
||||||
binary_align_loss_alpha: float = 1.0
|
binary_align_loss_alpha: float = 1.0
|
||||||
binary_align_loss_start_step: int = 20000
|
binary_align_loss_start_step: int = 20000
|
||||||
|
|
|
@ -2,81 +2,154 @@ from dataclasses import dataclass, field
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
from TTS.tts.configs.shared_configs import BaseTTSConfig
|
||||||
from TTS.tts.models.speedy_speech import SpeedySpeechArgs
|
from TTS.tts.models.forward_tts import ForwardTTSArgs
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SpeedySpeechConfig(BaseTTSConfig):
|
class SpeedySpeechConfig(BaseTTSConfig):
|
||||||
"""Defines parameters for Speedy Speech (feed-forward encoder-decoder) based models.
|
"""Configure `ForwardTTS` as SpeedySpeech model.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
>>> from TTS.tts.configs import SpeedySpeechConfig
|
>>> from TTS.tts.configs import SpeedySpeechConfig
|
||||||
>>> config = SpeedySpeechConfig()
|
>>> config = SpeedySpeechConfig()
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (str):
|
model (str):
|
||||||
Model name used for selecting the right model at initialization. Defaults to `speedy_speech`.
|
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
|
||||||
|
|
||||||
model_args (Coqpit):
|
model_args (Coqpit):
|
||||||
Model class arguments. Check `SpeedySpeechArgs` for more details. Defaults to `SpeedySpeechArgs()`.
|
Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`.
|
||||||
|
|
||||||
data_dep_init_steps (int):
|
data_dep_init_steps (int):
|
||||||
Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
|
Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
|
||||||
Activation Normalization that pre-computes normalization stats at the beginning and use the same values
|
Activation Normalization that pre-computes normalization stats at the beginning and use the same values
|
||||||
for the rest. Defaults to 10.
|
for the rest. Defaults to 10.
|
||||||
|
|
||||||
use_speaker_embedding (bool):
|
use_speaker_embedding (bool):
|
||||||
enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
|
enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
|
||||||
in the multi-speaker mode. Defaults to False.
|
in the multi-speaker mode. Defaults to False.
|
||||||
|
|
||||||
use_d_vector_file (bool):
|
use_d_vector_file (bool):
|
||||||
enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
|
enable /disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
|
||||||
|
|
||||||
d_vector_file (str):
|
d_vector_file (str):
|
||||||
Path to the file including pre-computed speaker embeddings. Defaults to None.
|
Path to the file including pre-computed speaker embeddings. Defaults to None.
|
||||||
noam_schedule (bool):
|
|
||||||
enable / disable the use of Noam LR scheduler. Defaults to False.
|
d_vector_dim (int):
|
||||||
warmup_steps (int):
|
Dimension of the external speaker embeddings. Defaults to 0.
|
||||||
Number of warm-up steps for the Noam scheduler. Defaults 4000.
|
|
||||||
|
optimizer (str):
|
||||||
|
Name of the model optimizer. Defaults to `RAdam`.
|
||||||
|
|
||||||
|
optimizer_params (dict):
|
||||||
|
Arguments of the model optimizer. Defaults to `{"betas": [0.9, 0.998], "weight_decay": 1e-6}`.
|
||||||
|
|
||||||
|
lr_scheduler (str):
|
||||||
|
Name of the learning rate scheduler. Defaults to `Noam`.
|
||||||
|
|
||||||
|
lr_scheduler_params (dict):
|
||||||
|
Arguments of the learning rate scheduler. Defaults to `{"warmup_steps": 4000}`.
|
||||||
|
|
||||||
lr (float):
|
lr (float):
|
||||||
Initial learning rate. Defaults to `1e-3`.
|
Initial learning rate. Defaults to `1e-3`.
|
||||||
|
|
||||||
|
grad_clip (float):
|
||||||
|
Gradient norm clipping value. Defaults to `5.0`.
|
||||||
|
|
||||||
|
spec_loss_type (str):
|
||||||
|
Type of the spectrogram loss. Check `ForwardTTSLoss` for possible values. Defaults to `l1`.
|
||||||
|
|
||||||
|
duration_loss_type (str):
|
||||||
|
Type of the duration loss. Check `ForwardTTSLoss` for possible values. Defaults to `huber`.
|
||||||
|
|
||||||
|
use_ssim_loss (bool):
|
||||||
|
Enable/disable the use of SSIM (Structural Similarity) loss. Defaults to True.
|
||||||
|
|
||||||
wd (float):
|
wd (float):
|
||||||
Weight decay coefficient. Defaults to `1e-7`.
|
Weight decay coefficient. Defaults to `1e-7`.
|
||||||
ssim_alpha (float):
|
|
||||||
Weight for the SSIM loss. If set <= 0, disables the SSIM loss. Defaults to 1.0.
|
ssim_loss_alpha (float):
|
||||||
huber_alpha (float):
|
Weight for the SSIM loss. If set 0, disables the SSIM loss. Defaults to 1.0.
|
||||||
Weight for the duration predictor's loss. Defaults to 1.0.
|
|
||||||
l1_alpha (float):
|
dur_loss_alpha (float):
|
||||||
Weight for the L1 spectrogram loss. If set <= 0, disables the L1 loss. Defaults to 1.0.
|
Weight for the duration predictor's loss. If set 0, disables the huber loss. Defaults to 1.0.
|
||||||
|
|
||||||
|
spec_loss_alpha (float):
|
||||||
|
Weight for the L1 spectrogram loss. If set 0, disables the L1 loss. Defaults to 1.0.
|
||||||
|
|
||||||
|
binary_loss_alpha (float):
|
||||||
|
Weight for the binary loss. If set 0, disables the binary loss. Defaults to 1.0.
|
||||||
|
|
||||||
|
binary_align_loss_start_step (int):
|
||||||
|
Start binary alignment loss after this many steps. Defaults to 20000.
|
||||||
|
|
||||||
min_seq_len (int):
|
min_seq_len (int):
|
||||||
Minimum input sequence length to be used at training.
|
Minimum input sequence length to be used at training.
|
||||||
|
|
||||||
max_seq_len (int):
|
max_seq_len (int):
|
||||||
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
|
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model: str = "speedy_speech"
|
model: str = "forward_tts"
|
||||||
# model specific params
|
|
||||||
model_args: SpeedySpeechArgs = field(default_factory=SpeedySpeechArgs)
|
# set model args as SpeedySpeech
|
||||||
|
model_args: ForwardTTSArgs = ForwardTTSArgs(
|
||||||
|
use_pitch=False,
|
||||||
|
encoder_type="residual_conv_bn",
|
||||||
|
encoder_params={
|
||||||
|
"kernel_size": 4,
|
||||||
|
"dilations": 4 * [1, 2, 4] + [1],
|
||||||
|
"num_conv_blocks": 2,
|
||||||
|
"num_res_blocks": 13,
|
||||||
|
},
|
||||||
|
decoder_type="residual_conv_bn",
|
||||||
|
decoder_params={
|
||||||
|
"kernel_size": 4,
|
||||||
|
"dilations": 4 * [1, 2, 4, 8] + [1],
|
||||||
|
"num_conv_blocks": 2,
|
||||||
|
"num_res_blocks": 17,
|
||||||
|
},
|
||||||
|
out_channels=80,
|
||||||
|
hidden_channels=128,
|
||||||
|
num_speakers=0,
|
||||||
|
positional_encoding=True,
|
||||||
|
)
|
||||||
|
|
||||||
# multi-speaker settings
|
# multi-speaker settings
|
||||||
use_speaker_embedding: bool = False
|
use_speaker_embedding: bool = False
|
||||||
use_d_vector_file: bool = False
|
use_d_vector_file: bool = False
|
||||||
d_vector_file: str = False
|
d_vector_file: str = False
|
||||||
|
d_vector_dim: int = 0
|
||||||
|
|
||||||
# optimizer parameters
|
# optimizer parameters
|
||||||
optimizer: str = "RAdam"
|
optimizer: str = "RAdam"
|
||||||
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
|
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
|
||||||
lr_scheduler: str = None
|
lr_scheduler: str = "NoamLR"
|
||||||
lr_scheduler_params: dict = None
|
lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
|
||||||
lr: float = 1e-4
|
lr: float = 1e-4
|
||||||
grad_clip: float = 5.0
|
grad_clip: float = 5.0
|
||||||
|
|
||||||
# loss params
|
# loss params
|
||||||
ssim_alpha: float = 1.0
|
spec_loss_type: str = "l1"
|
||||||
huber_alpha: float = 1.0
|
duration_loss_type: str = "huber"
|
||||||
l1_alpha: float = 1.0
|
use_ssim_loss: bool = True
|
||||||
|
ssim_loss_alpha: float = 1.0
|
||||||
|
dur_loss_alpha: float = 1.0
|
||||||
|
spec_loss_alpha: float = 1.0
|
||||||
|
aligner_loss_alpha: float = 1.0
|
||||||
|
binary_align_loss_alpha: float = 1.0
|
||||||
|
binary_align_loss_start_step: int = 20000
|
||||||
|
|
||||||
# overrides
|
# overrides
|
||||||
min_seq_len: int = 13
|
min_seq_len: int = 13
|
||||||
max_seq_len: int = 200
|
max_seq_len: int = 200
|
||||||
r: int = 1 # DO NOT CHANGE
|
r: int = 1 # DO NOT CHANGE
|
||||||
|
|
||||||
|
# dataset configs
|
||||||
|
compute_f0: bool = False
|
||||||
|
f0_cache_path: str = None
|
||||||
|
|
||||||
# testing
|
# testing
|
||||||
test_sentences: List[str] = field(
|
test_sentences: List[str] = field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
|
|
|
@ -0,0 +1,695 @@
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from coqpit import Coqpit
|
||||||
|
from torch import nn
|
||||||
|
from torch.cuda.amp.autocast_mode import autocast
|
||||||
|
|
||||||
|
from TTS.tts.layers.feed_forward.decoder import Decoder
|
||||||
|
from TTS.tts.layers.feed_forward.encoder import Encoder
|
||||||
|
from TTS.tts.layers.generic.aligner import AlignmentNetwork
|
||||||
|
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
|
||||||
|
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||||
|
from TTS.tts.models.base_tts import BaseTTS
|
||||||
|
from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
|
||||||
|
from TTS.tts.utils.visual import plot_alignment, plot_pitch, plot_spectrogram
|
||||||
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ForwardTTSArgs(Coqpit):
|
||||||
|
"""ForwardTTS Model arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
|
||||||
|
num_chars (int):
|
||||||
|
Number of characters in the vocabulary. Defaults to 100.
|
||||||
|
|
||||||
|
out_channels (int):
|
||||||
|
Number of output channels. Defaults to 80.
|
||||||
|
|
||||||
|
hidden_channels (int):
|
||||||
|
Number of base hidden channels of the model. Defaults to 512.
|
||||||
|
|
||||||
|
num_speakers (int):
|
||||||
|
Number of speakers for the speaker embedding layer. Defaults to 0.
|
||||||
|
|
||||||
|
use_aligner (bool):
|
||||||
|
Whether to use aligner network to learn the text to speech alignment or use pre-computed durations.
|
||||||
|
If set False, durations should be computed by `TTS/bin/compute_attention_masks.py` and path to the
|
||||||
|
pre-computed durations must be provided to `config.datasets[0].meta_file_attn_mask`. Defaults to True.
|
||||||
|
|
||||||
|
use_pitch (bool):
|
||||||
|
Use pitch predictor to learn the pitch. Defaults to True.
|
||||||
|
|
||||||
|
duration_predictor_hidden_channels (int):
|
||||||
|
Number of hidden channels in the duration predictor. Defaults to 256.
|
||||||
|
|
||||||
|
duration_predictor_dropout_p (float):
|
||||||
|
Dropout rate for the duration predictor. Defaults to 0.1.
|
||||||
|
|
||||||
|
duration_predictor_kernel_size (int):
|
||||||
|
Kernel size of conv layers in the duration predictor. Defaults to 3.
|
||||||
|
|
||||||
|
pitch_predictor_hidden_channels (int):
|
||||||
|
Number of hidden channels in the pitch predictor. Defaults to 256.
|
||||||
|
|
||||||
|
pitch_predictor_dropout_p (float):
|
||||||
|
Dropout rate for the pitch predictor. Defaults to 0.1.
|
||||||
|
|
||||||
|
pitch_predictor_kernel_size (int):
|
||||||
|
Kernel size of conv layers in the pitch predictor. Defaults to 3.
|
||||||
|
|
||||||
|
pitch_embedding_kernel_size (int):
|
||||||
|
Kernel size of the projection layer in the pitch predictor. Defaults to 3.
|
||||||
|
|
||||||
|
positional_encoding (bool):
|
||||||
|
Whether to use positional encoding. Defaults to True.
|
||||||
|
|
||||||
|
positional_encoding_use_scale (bool):
|
||||||
|
Whether to use a learnable scale coeff in the positional encoding. Defaults to True.
|
||||||
|
|
||||||
|
length_scale (int):
|
||||||
|
Length scale that multiplies the predicted durations. Larger values result slower speech. Defaults to 1.0.
|
||||||
|
|
||||||
|
encoder_type (str):
|
||||||
|
Type of the encoder module. One of the encoders available in :class:`TTS.tts.layers.feed_forward.encoder`.
|
||||||
|
Defaults to `fftransformer` as in the paper.
|
||||||
|
|
||||||
|
encoder_params (dict):
|
||||||
|
Parameters of the encoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}```
|
||||||
|
|
||||||
|
decoder_type (str):
|
||||||
|
Type of the decoder module. One of the decoders available in :class:`TTS.tts.layers.feed_forward.decoder`.
|
||||||
|
Defaults to `fftransformer` as in the paper.
|
||||||
|
|
||||||
|
decoder_params (str):
|
||||||
|
Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}```
|
||||||
|
|
||||||
|
use_d_vetor (bool):
|
||||||
|
Whether to use precomputed d-vectors for multi-speaker training. Defaults to False.
|
||||||
|
|
||||||
|
d_vector_dim (int):
|
||||||
|
Number of channels of the d-vectors. Defaults to 0.
|
||||||
|
|
||||||
|
detach_duration_predictor (bool):
|
||||||
|
Detach the input to the duration predictor from the earlier computation graph so that the duraiton loss
|
||||||
|
does not pass to the earlier layers. Defaults to True.
|
||||||
|
|
||||||
|
max_duration (int):
|
||||||
|
Maximum duration accepted by the model. Defaults to 75.
|
||||||
|
"""
|
||||||
|
|
||||||
|
num_chars: int = None
|
||||||
|
out_channels: int = 80
|
||||||
|
hidden_channels: int = 384
|
||||||
|
num_speakers: int = 0
|
||||||
|
use_aligner: bool = True
|
||||||
|
use_pitch: bool = True
|
||||||
|
pitch_predictor_hidden_channels: int = 256
|
||||||
|
pitch_predictor_kernel_size: int = 3
|
||||||
|
pitch_predictor_dropout_p: float = 0.1
|
||||||
|
pitch_embedding_kernel_size: int = 3
|
||||||
|
duration_predictor_hidden_channels: int = 256
|
||||||
|
duration_predictor_kernel_size: int = 3
|
||||||
|
duration_predictor_dropout_p: float = 0.1
|
||||||
|
positional_encoding: bool = True
|
||||||
|
poisitonal_encoding_use_scale: bool = True
|
||||||
|
length_scale: int = 1
|
||||||
|
encoder_type: str = "fftransformer"
|
||||||
|
encoder_params: dict = field(
|
||||||
|
default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
|
||||||
|
)
|
||||||
|
decoder_type: str = "fftransformer"
|
||||||
|
decoder_params: dict = field(
|
||||||
|
default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
|
||||||
|
)
|
||||||
|
use_d_vector: bool = False
|
||||||
|
d_vector_dim: int = 0
|
||||||
|
detach_duration_predictor: bool = False
|
||||||
|
max_duration: int = 75
|
||||||
|
|
||||||
|
|
||||||
|
class ForwardTTS(BaseTTS):
|
||||||
|
"""General forward TTS model implementation that uses an encoder-decoder architecture with an optional alignment
|
||||||
|
network and a pitch predictor.
|
||||||
|
|
||||||
|
If the alignment network is used, the model learns the text-to-speech alignment
|
||||||
|
from the data instead of using pre-computed durations.
|
||||||
|
|
||||||
|
If the pitch predictor is used, the model trains a pitch predictor that predicts average pitch value for each
|
||||||
|
input character as in the FastPitch model.
|
||||||
|
|
||||||
|
`ForwardTTS` can be configured to one of these architectures,
|
||||||
|
|
||||||
|
- FastPitch
|
||||||
|
- SpeedySpeech
|
||||||
|
- FastSpeech
|
||||||
|
- TODO: FastSpeech2 (requires average speech energy predictor)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (Coqpit): Model coqpit class.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from TTS.tts.models.fast_pitch import ForwardTTS, ForwardTTSArgs
|
||||||
|
>>> config = ForwardTTSArgs()
|
||||||
|
>>> model = ForwardTTS(config)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=dangerous-default-value
|
||||||
|
def __init__(self, config: Coqpit):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# don't use isintance not to import recursively
|
||||||
|
if "Config" in config.__class__.__name__:
|
||||||
|
if "characters" in config:
|
||||||
|
# loading from FasrPitchConfig
|
||||||
|
_, self.config, num_chars = self.get_characters(config)
|
||||||
|
config.model_args.num_chars = num_chars
|
||||||
|
self.args = self.config.model_args
|
||||||
|
else:
|
||||||
|
# loading from ForwardTTSArgs
|
||||||
|
self.config = config
|
||||||
|
self.args = config.model_args
|
||||||
|
elif isinstance(config, ForwardTTSArgs):
|
||||||
|
self.args = config
|
||||||
|
self.config = config
|
||||||
|
else:
|
||||||
|
raise ValueError("config must be either a *Config or ForwardTTSArgs")
|
||||||
|
|
||||||
|
self.max_duration = self.args.max_duration
|
||||||
|
self.use_aligner = self.args.use_aligner
|
||||||
|
self.use_pitch = self.args.use_pitch
|
||||||
|
self.use_binary_alignment_loss = False
|
||||||
|
|
||||||
|
self.length_scale = (
|
||||||
|
float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
self.emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
|
||||||
|
|
||||||
|
self.encoder = Encoder(
|
||||||
|
self.args.hidden_channels,
|
||||||
|
self.args.hidden_channels,
|
||||||
|
self.args.encoder_type,
|
||||||
|
self.args.encoder_params,
|
||||||
|
self.args.d_vector_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.positional_encoding:
|
||||||
|
self.pos_encoder = PositionalEncoding(self.args.hidden_channels)
|
||||||
|
|
||||||
|
self.decoder = Decoder(
|
||||||
|
self.args.out_channels,
|
||||||
|
self.args.hidden_channels,
|
||||||
|
self.args.decoder_type,
|
||||||
|
self.args.decoder_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.duration_predictor = DurationPredictor(
|
||||||
|
self.args.hidden_channels + self.args.d_vector_dim,
|
||||||
|
self.args.duration_predictor_hidden_channels,
|
||||||
|
self.args.duration_predictor_kernel_size,
|
||||||
|
self.args.duration_predictor_dropout_p,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.use_pitch:
|
||||||
|
self.pitch_predictor = DurationPredictor(
|
||||||
|
self.args.hidden_channels + self.args.d_vector_dim,
|
||||||
|
self.args.pitch_predictor_hidden_channels,
|
||||||
|
self.args.pitch_predictor_kernel_size,
|
||||||
|
self.args.pitch_predictor_dropout_p,
|
||||||
|
)
|
||||||
|
self.pitch_emb = nn.Conv1d(
|
||||||
|
1,
|
||||||
|
self.args.hidden_channels,
|
||||||
|
kernel_size=self.args.pitch_embedding_kernel_size,
|
||||||
|
padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.num_speakers > 1 and not self.args.use_d_vector:
|
||||||
|
# speaker embedding layer
|
||||||
|
self.emb_g = nn.Embedding(self.args.num_speakers, self.args.d_vector_dim)
|
||||||
|
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
||||||
|
|
||||||
|
if self.args.d_vector_dim > 0 and self.args.d_vector_dim != self.args.hidden_channels:
|
||||||
|
self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
|
||||||
|
|
||||||
|
if self.args.use_aligner:
|
||||||
|
self.aligner = AlignmentNetwork(
|
||||||
|
in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_attn(dr, x_mask, y_mask=None):
|
||||||
|
"""Generate an attention mask from the durations.
|
||||||
|
|
||||||
|
Shapes
|
||||||
|
- dr: :math:`(B, T_{en})`
|
||||||
|
- x_mask: :math:`(B, T_{en})`
|
||||||
|
- y_mask: :math:`(B, T_{de})`
|
||||||
|
"""
|
||||||
|
# compute decode mask from the durations
|
||||||
|
if y_mask is None:
|
||||||
|
y_lengths = dr.sum(1).long()
|
||||||
|
y_lengths[y_lengths < 1] = 1
|
||||||
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
|
||||||
|
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||||
|
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
|
||||||
|
return attn
|
||||||
|
|
||||||
|
def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
|
||||||
|
"""Generate attention alignment map from durations and
|
||||||
|
expand encoder outputs
|
||||||
|
|
||||||
|
Shapes
|
||||||
|
- en: :math:`(B, D_{en}, T_{en})`
|
||||||
|
- dr: :math:`(B, T_{en})`
|
||||||
|
- x_mask: :math:`(B, T_{en})`
|
||||||
|
- y_mask: :math:`(B, T_{de})`
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
- encoder output: :math:`[a,b,c,d]`
|
||||||
|
- durations: :math:`[1, 3, 2, 1]`
|
||||||
|
|
||||||
|
- expanded: :math:`[a, b, b, b, c, c, d]`
|
||||||
|
- attention map: :math:`[[0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0]]`
|
||||||
|
"""
|
||||||
|
attn = self.generate_attn(dr, x_mask, y_mask)
|
||||||
|
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2)
|
||||||
|
return o_en_ex, attn
|
||||||
|
|
||||||
|
def format_durations(self, o_dr_log, x_mask):
|
||||||
|
"""Format predicted durations.
|
||||||
|
1. Convert to linear scale from log scale
|
||||||
|
2. Apply the length scale for speed adjustment
|
||||||
|
3. Apply masking.
|
||||||
|
4. Cast 0 durations to 1.
|
||||||
|
5. Round the duration values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
o_dr_log: Log scale durations.
|
||||||
|
x_mask: Input text mask.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- o_dr_log: :math:`(B, T_{de})`
|
||||||
|
- x_mask: :math:`(B, T_{en})`
|
||||||
|
"""
|
||||||
|
o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale
|
||||||
|
o_dr[o_dr < 1] = 1.0
|
||||||
|
o_dr = torch.round(o_dr)
|
||||||
|
return o_dr
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _concat_speaker_embedding(o_en, g):
|
||||||
|
g_exp = g.expand(-1, -1, o_en.size(-1)) # [B, C, T_en]
|
||||||
|
o_en = torch.cat([o_en, g_exp], 1)
|
||||||
|
return o_en
|
||||||
|
|
||||||
|
def _sum_speaker_embedding(self, x, g):
|
||||||
|
# project g to decoder dim.
|
||||||
|
if hasattr(self, "proj_g"):
|
||||||
|
g = self.proj_g(g)
|
||||||
|
return x + g
|
||||||
|
|
||||||
|
def _forward_encoder(
|
||||||
|
self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None
|
||||||
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||||
|
"""Encoding forward pass.
|
||||||
|
|
||||||
|
1. Embed speaker IDs if multi-speaker mode.
|
||||||
|
2. Embed character sequences.
|
||||||
|
3. Run the encoder network.
|
||||||
|
4. Concat speaker embedding to the encoder output for the duration predictor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.LongTensor): Input sequence IDs.
|
||||||
|
x_mask (torch.FloatTensor): Input squence mask.
|
||||||
|
g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor, torch.tensor]:
|
||||||
|
encoder output, encoder output for the duration predictor, input sequence mask, speaker embeddings,
|
||||||
|
character embeddings
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`(B, T_{en})`
|
||||||
|
- x_mask: :math:`(B, 1, T_{en})`
|
||||||
|
- g: :math:`(B, C)`
|
||||||
|
"""
|
||||||
|
if hasattr(self, "emb_g"):
|
||||||
|
g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1]
|
||||||
|
if g is not None:
|
||||||
|
g = g.unsqueeze(-1)
|
||||||
|
# [B, T, C]
|
||||||
|
x_emb = self.emb(x)
|
||||||
|
# encoder pass
|
||||||
|
o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
|
||||||
|
# speaker conditioning for duration predictor
|
||||||
|
if g is not None:
|
||||||
|
o_en_dp = self._concat_speaker_embedding(o_en, g)
|
||||||
|
else:
|
||||||
|
o_en_dp = o_en
|
||||||
|
return o_en, o_en_dp, x_mask, g, x_emb
|
||||||
|
|
||||||
|
def _forward_decoder(
|
||||||
|
self,
|
||||||
|
o_en: torch.FloatTensor,
|
||||||
|
dr: torch.IntTensor,
|
||||||
|
x_mask: torch.FloatTensor,
|
||||||
|
y_lengths: torch.IntTensor,
|
||||||
|
g: torch.FloatTensor,
|
||||||
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||||
|
"""Decoding forward pass.
|
||||||
|
|
||||||
|
1. Compute the decoder output mask
|
||||||
|
2. Expand encoder output with the durations.
|
||||||
|
3. Apply position encoding.
|
||||||
|
4. Add speaker embeddings if multi-speaker mode.
|
||||||
|
5. Run the decoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
o_en (torch.FloatTensor): Encoder output.
|
||||||
|
dr (torch.IntTensor): Ground truth durations or alignment network durations.
|
||||||
|
x_mask (torch.IntTensor): Input sequence mask.
|
||||||
|
y_lengths (torch.IntTensor): Output sequence lengths.
|
||||||
|
g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations.
|
||||||
|
"""
|
||||||
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
|
||||||
|
# expand o_en with durations
|
||||||
|
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
|
||||||
|
# positional encoding
|
||||||
|
if hasattr(self, "pos_encoder"):
|
||||||
|
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
|
||||||
|
# speaker embedding
|
||||||
|
if g is not None:
|
||||||
|
o_en_ex = self._sum_speaker_embedding(o_en_ex, g)
|
||||||
|
# decoder pass
|
||||||
|
o_de = self.decoder(o_en_ex, y_mask, g=g)
|
||||||
|
return o_de.transpose(1, 2), attn.transpose(1, 2)
|
||||||
|
|
||||||
|
def _forward_pitch_predictor(
|
||||||
|
self,
|
||||||
|
o_en: torch.FloatTensor,
|
||||||
|
x_mask: torch.IntTensor,
|
||||||
|
pitch: torch.FloatTensor = None,
|
||||||
|
dr: torch.IntTensor = None,
|
||||||
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||||
|
"""Pitch predictor forward pass.
|
||||||
|
|
||||||
|
1. Predict pitch from encoder outputs.
|
||||||
|
2. In training - Compute average pitch values for each input character from the ground truth pitch values.
|
||||||
|
3. Embed average pitch values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
o_en (torch.FloatTensor): Encoder output.
|
||||||
|
x_mask (torch.IntTensor): Input sequence mask.
|
||||||
|
pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
|
||||||
|
dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- o_en: :math:`(B, C, T_{en})`
|
||||||
|
- x_mask: :math:`(B, 1, T_{en})`
|
||||||
|
- pitch: :math:`(B, 1, T_{de})`
|
||||||
|
- dr: :math:`(B, T_{en})`
|
||||||
|
"""
|
||||||
|
o_pitch = self.pitch_predictor(o_en, x_mask)
|
||||||
|
if pitch is not None:
|
||||||
|
avg_pitch = average_over_durations(pitch, dr)
|
||||||
|
o_pitch_emb = self.pitch_emb(avg_pitch)
|
||||||
|
return o_pitch_emb, o_pitch, avg_pitch
|
||||||
|
o_pitch_emb = self.pitch_emb(o_pitch)
|
||||||
|
return o_pitch_emb, o_pitch
|
||||||
|
|
||||||
|
def _forward_aligner(
|
||||||
|
self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor
|
||||||
|
) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||||
|
"""Aligner forward pass.
|
||||||
|
|
||||||
|
1. Compute a mask to apply to the attention map.
|
||||||
|
2. Run the alignment network.
|
||||||
|
3. Apply MAS to compute the hard alignment map.
|
||||||
|
4. Compute the durations from the hard alignment map.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.FloatTensor): Input sequence.
|
||||||
|
y (torch.FloatTensor): Output sequence.
|
||||||
|
x_mask (torch.IntTensor): Input sequence mask.
|
||||||
|
y_mask (torch.IntTensor): Output sequence mask.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||||
|
Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials,
|
||||||
|
hard alignment map.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`[B, T_en, C_en]`
|
||||||
|
- y: :math:`[B, T_de, C_de]`
|
||||||
|
- x_mask: :math:`[B, 1, T_en]`
|
||||||
|
- y_mask: :math:`[B, 1, T_de]`
|
||||||
|
|
||||||
|
- o_alignment_dur: :math:`[B, T_en]`
|
||||||
|
- alignment_soft: :math:`[B, T_en, T_de]`
|
||||||
|
- alignment_logprob: :math:`[B, 1, T_de, T_en]`
|
||||||
|
- alignment_mas: :math:`[B, T_en, T_de]`
|
||||||
|
"""
|
||||||
|
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
||||||
|
alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None)
|
||||||
|
alignment_mas = maximum_path(
|
||||||
|
alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
|
||||||
|
)
|
||||||
|
o_alignment_dur = torch.sum(alignment_mas, -1).int()
|
||||||
|
alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
|
||||||
|
return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.LongTensor,
|
||||||
|
x_lengths: torch.LongTensor,
|
||||||
|
y_lengths: torch.LongTensor,
|
||||||
|
y: torch.FloatTensor = None,
|
||||||
|
dr: torch.IntTensor = None,
|
||||||
|
pitch: torch.FloatTensor = None,
|
||||||
|
aux_input: Dict = {"d_vectors": None, "speaker_ids": None}, # pylint: disable=unused-argument
|
||||||
|
) -> Dict:
|
||||||
|
"""Model's forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.LongTensor): Input character sequences.
|
||||||
|
x_lengths (torch.LongTensor): Input sequence lengths.
|
||||||
|
y_lengths (torch.LongTensor): Output sequnce lengths. Defaults to None.
|
||||||
|
y (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None.
|
||||||
|
dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None.
|
||||||
|
pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None.
|
||||||
|
aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- x: :math:`[B, T_max]`
|
||||||
|
- x_lengths: :math:`[B]`
|
||||||
|
- y_lengths: :math:`[B]`
|
||||||
|
- y: :math:`[B, T_max2]`
|
||||||
|
- dr: :math:`[B, T_max]`
|
||||||
|
- g: :math:`[B, C]`
|
||||||
|
- pitch: :math:`[B, 1, T]`
|
||||||
|
"""
|
||||||
|
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
||||||
|
# compute sequence masks
|
||||||
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()
|
||||||
|
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float()
|
||||||
|
# encoder pass
|
||||||
|
o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
|
||||||
|
# duration predictor pass
|
||||||
|
if self.args.detach_duration_predictor:
|
||||||
|
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
|
||||||
|
else:
|
||||||
|
o_dr_log = self.duration_predictor(o_en_dp, x_mask)
|
||||||
|
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
|
||||||
|
# generate attn mask from predicted durations
|
||||||
|
o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
|
||||||
|
# aligner
|
||||||
|
o_alignment_dur = None
|
||||||
|
alignment_soft = None
|
||||||
|
alignment_logprob = None
|
||||||
|
alignment_mas = None
|
||||||
|
if self.use_aligner:
|
||||||
|
o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner(
|
||||||
|
x_emb, y, x_mask, y_mask
|
||||||
|
)
|
||||||
|
alignment_soft = alignment_soft.transpose(1, 2)
|
||||||
|
alignment_mas = alignment_mas.transpose(1, 2)
|
||||||
|
dr = o_alignment_dur
|
||||||
|
# pitch predictor pass
|
||||||
|
o_pitch = None
|
||||||
|
avg_pitch = None
|
||||||
|
if self.args.use_pitch:
|
||||||
|
o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr)
|
||||||
|
o_en = o_en + o_pitch_emb
|
||||||
|
# decoder pass
|
||||||
|
o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g)
|
||||||
|
outputs = {
|
||||||
|
"model_outputs": o_de, # [B, T, C]
|
||||||
|
"durations_log": o_dr_log.squeeze(1), # [B, T]
|
||||||
|
"durations": o_dr.squeeze(1), # [B, T]
|
||||||
|
"attn_durations": o_attn, # for visualization [B, T_en, T_de']
|
||||||
|
"pitch_avg": o_pitch,
|
||||||
|
"pitch_avg_gt": avg_pitch,
|
||||||
|
"alignments": attn, # [B, T_de, T_en]
|
||||||
|
"alignment_soft": alignment_soft,
|
||||||
|
"alignment_mas": alignment_mas,
|
||||||
|
"o_alignment_dur": o_alignment_dur,
|
||||||
|
"alignment_logprob": alignment_logprob,
|
||||||
|
"x_mask": x_mask,
|
||||||
|
"y_mask": y_mask,
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
|
||||||
|
"""Model's inference pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.LongTensor): Input character sequence.
|
||||||
|
aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`.
|
||||||
|
|
||||||
|
Shapes:
|
||||||
|
- x: [B, T_max]
|
||||||
|
- x_lengths: [B]
|
||||||
|
- g: [B, C]
|
||||||
|
"""
|
||||||
|
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
||||||
|
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
|
||||||
|
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float()
|
||||||
|
# encoder pass
|
||||||
|
o_en, o_en_dp, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
|
||||||
|
# duration predictor pass
|
||||||
|
o_dr_log = self.duration_predictor(o_en_dp, x_mask)
|
||||||
|
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
||||||
|
y_lengths = o_dr.sum(1)
|
||||||
|
# pitch predictor pass
|
||||||
|
o_pitch = None
|
||||||
|
if self.args.use_pitch:
|
||||||
|
o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask)
|
||||||
|
o_en = o_en + o_pitch_emb
|
||||||
|
# decoder pass
|
||||||
|
o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=g)
|
||||||
|
outputs = {
|
||||||
|
"model_outputs": o_de,
|
||||||
|
"alignments": attn,
|
||||||
|
"pitch": o_pitch,
|
||||||
|
"durations_log": o_dr_log,
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def train_step(self, batch: dict, criterion: nn.Module):
|
||||||
|
text_input = batch["text_input"]
|
||||||
|
text_lengths = batch["text_lengths"]
|
||||||
|
mel_input = batch["mel_input"]
|
||||||
|
mel_lengths = batch["mel_lengths"]
|
||||||
|
pitch = batch["pitch"] if self.args.use_pitch else None
|
||||||
|
d_vectors = batch["d_vectors"]
|
||||||
|
speaker_ids = batch["speaker_ids"]
|
||||||
|
durations = batch["durations"]
|
||||||
|
aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
|
||||||
|
|
||||||
|
# forward pass
|
||||||
|
outputs = self.forward(
|
||||||
|
text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input
|
||||||
|
)
|
||||||
|
# use aligner's output as the duration target
|
||||||
|
if self.use_aligner:
|
||||||
|
durations = outputs["o_alignment_dur"]
|
||||||
|
# use float32 in AMP
|
||||||
|
with autocast(enabled=False):
|
||||||
|
# compute loss
|
||||||
|
loss_dict = criterion(
|
||||||
|
decoder_output=outputs["model_outputs"],
|
||||||
|
decoder_target=mel_input,
|
||||||
|
decoder_output_lens=mel_lengths,
|
||||||
|
dur_output=outputs["durations_log"],
|
||||||
|
dur_target=durations,
|
||||||
|
pitch_output=outputs["pitch_avg"] if self.use_pitch else None,
|
||||||
|
pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
|
||||||
|
input_lens=text_lengths,
|
||||||
|
alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
|
||||||
|
alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
|
||||||
|
alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None,
|
||||||
|
)
|
||||||
|
# compute duration error
|
||||||
|
durations_pred = outputs["durations"]
|
||||||
|
duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum()
|
||||||
|
loss_dict["duration_error"] = duration_error
|
||||||
|
|
||||||
|
return outputs, loss_dict
|
||||||
|
|
||||||
|
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use
|
||||||
|
model_outputs = outputs["model_outputs"]
|
||||||
|
alignments = outputs["alignments"]
|
||||||
|
mel_input = batch["mel_input"]
|
||||||
|
|
||||||
|
pred_spec = model_outputs[0].data.cpu().numpy()
|
||||||
|
gt_spec = mel_input[0].data.cpu().numpy()
|
||||||
|
align_img = alignments[0].data.cpu().numpy()
|
||||||
|
|
||||||
|
figures = {
|
||||||
|
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
|
||||||
|
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
|
||||||
|
"alignment": plot_alignment(align_img, output_fig=False),
|
||||||
|
}
|
||||||
|
|
||||||
|
# plot pitch figures
|
||||||
|
if self.args.use_pitch:
|
||||||
|
pitch = batch["pitch"]
|
||||||
|
pitch_avg_expanded, _ = self.expand_encoder_outputs(
|
||||||
|
outputs["pitch_avg"], outputs["durations"], outputs["x_mask"], outputs["y_mask"]
|
||||||
|
)
|
||||||
|
pitch = pitch[0, 0].data.cpu().numpy()
|
||||||
|
# TODO: denormalize before plotting
|
||||||
|
pitch = abs(pitch)
|
||||||
|
pitch_avg_expanded = abs(pitch_avg_expanded[0, 0]).data.cpu().numpy()
|
||||||
|
pitch_figures = {
|
||||||
|
"pitch_ground_truth": plot_pitch(pitch, gt_spec, ap, output_fig=False),
|
||||||
|
"pitch_avg_predicted": plot_pitch(pitch_avg_expanded, pred_spec, ap, output_fig=False),
|
||||||
|
}
|
||||||
|
figures.update(pitch_figures)
|
||||||
|
|
||||||
|
# plot the attention mask computed from the predicted durations
|
||||||
|
if "attn_durations" in outputs:
|
||||||
|
alignments_hat = outputs["attn_durations"][0].data.cpu().numpy()
|
||||||
|
figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False)
|
||||||
|
|
||||||
|
# Sample audio
|
||||||
|
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
||||||
|
return figures, {"audio": train_audio}
|
||||||
|
|
||||||
|
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||||
|
return self.train_step(batch, criterion)
|
||||||
|
|
||||||
|
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
||||||
|
return self.train_log(ap, batch, outputs)
|
||||||
|
|
||||||
|
def load_checkpoint(
|
||||||
|
self, config, checkpoint_path, eval=False
|
||||||
|
): # pylint: disable=unused-argument, redefined-builtin
|
||||||
|
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||||
|
self.load_state_dict(state["model"])
|
||||||
|
if eval:
|
||||||
|
self.eval()
|
||||||
|
assert not self.training
|
||||||
|
|
||||||
|
def get_criterion(self):
|
||||||
|
from TTS.tts.layers.losses import ForwardTTSLoss # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
return ForwardTTSLoss(self.config)
|
||||||
|
|
||||||
|
def on_train_step_start(self, trainer):
|
||||||
|
"""Enable binary alignment loss when needed"""
|
||||||
|
if trainer.total_steps_done > self.config.binary_align_loss_start_step:
|
||||||
|
self.use_binary_alignment_loss = True
|
|
@ -0,0 +1,149 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch as T
|
||||||
|
|
||||||
|
from TTS.tts.models.forward_tts import ForwardTTS, ForwardTTSArgs
|
||||||
|
from TTS.tts.utils.helpers import sequence_mask
|
||||||
|
|
||||||
|
# pylint: disable=unused-variable
|
||||||
|
|
||||||
|
|
||||||
|
def expand_encoder_outputs_test():
|
||||||
|
model = ForwardTTS(ForwardTTSArgs(num_chars=10))
|
||||||
|
|
||||||
|
inputs = T.rand(2, 5, 57)
|
||||||
|
durations = T.randint(1, 4, (2, 57))
|
||||||
|
|
||||||
|
x_mask = T.ones(2, 1, 57)
|
||||||
|
y_mask = T.ones(2, 1, durations.sum(1).max())
|
||||||
|
|
||||||
|
expanded, _ = model.expand_encoder_outputs(inputs, durations, x_mask, y_mask)
|
||||||
|
|
||||||
|
for b in range(durations.shape[0]):
|
||||||
|
index = 0
|
||||||
|
for idx, dur in enumerate(durations[b]):
|
||||||
|
diff = (
|
||||||
|
expanded[b, :, index : index + dur.item()]
|
||||||
|
- inputs[b, :, idx].repeat(dur.item()).view(expanded[b, :, index : index + dur.item()].shape)
|
||||||
|
).sum()
|
||||||
|
assert abs(diff) < 1e-6, diff
|
||||||
|
index += dur
|
||||||
|
|
||||||
|
|
||||||
|
def model_input_output_test():
|
||||||
|
"""Assert the output shapes of the model in different modes"""
|
||||||
|
|
||||||
|
# VANILLA MODEL
|
||||||
|
model = ForwardTTS(ForwardTTSArgs(num_chars=10, use_pitch=False, use_aligner=False))
|
||||||
|
|
||||||
|
x = T.randint(0, 10, (2, 21))
|
||||||
|
x_lengths = T.randint(10, 22, (2,))
|
||||||
|
x_lengths[-1] = 21
|
||||||
|
x_mask = sequence_mask(x_lengths).unsqueeze(1).long()
|
||||||
|
durations = T.randint(1, 4, (2, 21))
|
||||||
|
durations = durations * x_mask.squeeze(1)
|
||||||
|
y_lengths = durations.sum(1)
|
||||||
|
y_mask = sequence_mask(y_lengths).unsqueeze(1).long()
|
||||||
|
|
||||||
|
outputs = model.forward(x, x_lengths, y_lengths, dr=durations)
|
||||||
|
|
||||||
|
assert outputs["model_outputs"].shape == (2, durations.sum(1).max(), 80)
|
||||||
|
assert outputs["durations_log"].shape == (2, 21)
|
||||||
|
assert outputs["durations"].shape == (2, 21)
|
||||||
|
assert outputs["alignments"].shape == (2, durations.sum(1).max(), 21)
|
||||||
|
assert (outputs["x_mask"] - x_mask).sum() == 0.0
|
||||||
|
assert (outputs["y_mask"] - y_mask).sum() == 0.0
|
||||||
|
|
||||||
|
assert outputs["alignment_soft"] == None
|
||||||
|
assert outputs["alignment_mas"] == None
|
||||||
|
assert outputs["alignment_logprob"] == None
|
||||||
|
assert outputs["o_alignment_dur"] == None
|
||||||
|
assert outputs["pitch_avg"] == None
|
||||||
|
assert outputs["pitch_avg_gt"] == None
|
||||||
|
|
||||||
|
# USE PITCH
|
||||||
|
model = ForwardTTS(ForwardTTSArgs(num_chars=10, use_pitch=True, use_aligner=False))
|
||||||
|
|
||||||
|
x = T.randint(0, 10, (2, 21))
|
||||||
|
x_lengths = T.randint(10, 22, (2,))
|
||||||
|
x_lengths[-1] = 21
|
||||||
|
x_mask = sequence_mask(x_lengths).unsqueeze(1).long()
|
||||||
|
durations = T.randint(1, 4, (2, 21))
|
||||||
|
durations = durations * x_mask.squeeze(1)
|
||||||
|
y_lengths = durations.sum(1)
|
||||||
|
y_mask = sequence_mask(y_lengths).unsqueeze(1).long()
|
||||||
|
pitch = T.rand(2, 1, y_lengths.max())
|
||||||
|
|
||||||
|
outputs = model.forward(x, x_lengths, y_lengths, dr=durations, pitch=pitch)
|
||||||
|
|
||||||
|
assert outputs["model_outputs"].shape == (2, durations.sum(1).max(), 80)
|
||||||
|
assert outputs["durations_log"].shape == (2, 21)
|
||||||
|
assert outputs["durations"].shape == (2, 21)
|
||||||
|
assert outputs["alignments"].shape == (2, durations.sum(1).max(), 21)
|
||||||
|
assert (outputs["x_mask"] - x_mask).sum() == 0.0
|
||||||
|
assert (outputs["y_mask"] - y_mask).sum() == 0.0
|
||||||
|
assert outputs["pitch_avg"].shape == (2, 1, 21)
|
||||||
|
assert outputs["pitch_avg_gt"].shape == (2, 1, 21)
|
||||||
|
|
||||||
|
assert outputs["alignment_soft"] == None
|
||||||
|
assert outputs["alignment_mas"] == None
|
||||||
|
assert outputs["alignment_logprob"] == None
|
||||||
|
assert outputs["o_alignment_dur"] == None
|
||||||
|
|
||||||
|
# USE ALIGNER NETWORK
|
||||||
|
model = ForwardTTS(ForwardTTSArgs(num_chars=10, use_pitch=False, use_aligner=True))
|
||||||
|
|
||||||
|
x = T.randint(0, 10, (2, 21))
|
||||||
|
x_lengths = T.randint(10, 22, (2,))
|
||||||
|
x_lengths[-1] = 21
|
||||||
|
x_mask = sequence_mask(x_lengths).unsqueeze(1).long()
|
||||||
|
durations = T.randint(1, 4, (2, 21))
|
||||||
|
durations = durations * x_mask.squeeze(1)
|
||||||
|
y_lengths = durations.sum(1)
|
||||||
|
y_mask = sequence_mask(y_lengths).unsqueeze(1).long()
|
||||||
|
y = T.rand(2, y_lengths.max(), 80)
|
||||||
|
|
||||||
|
outputs = model.forward(x, x_lengths, y_lengths, dr=durations, y=y)
|
||||||
|
|
||||||
|
assert outputs["model_outputs"].shape == (2, durations.sum(1).max(), 80)
|
||||||
|
assert outputs["durations_log"].shape == (2, 21)
|
||||||
|
assert outputs["durations"].shape == (2, 21)
|
||||||
|
assert outputs["alignments"].shape == (2, durations.sum(1).max(), 21)
|
||||||
|
assert (outputs["x_mask"] - x_mask).sum() == 0.0
|
||||||
|
assert (outputs["y_mask"] - y_mask).sum() == 0.0
|
||||||
|
assert outputs["alignment_soft"].shape == (2, durations.sum(1).max(), 21)
|
||||||
|
assert outputs["alignment_mas"].shape == (2, durations.sum(1).max(), 21)
|
||||||
|
assert outputs["alignment_logprob"].shape == (2, 1, durations.sum(1).max(), 21)
|
||||||
|
assert outputs["o_alignment_dur"].shape == (2, 21)
|
||||||
|
|
||||||
|
assert outputs["pitch_avg"] == None
|
||||||
|
assert outputs["pitch_avg_gt"] == None
|
||||||
|
|
||||||
|
# USE ALIGNER NETWORK AND PITCH
|
||||||
|
model = ForwardTTS(ForwardTTSArgs(num_chars=10, use_pitch=True, use_aligner=True))
|
||||||
|
|
||||||
|
x = T.randint(0, 10, (2, 21))
|
||||||
|
x_lengths = T.randint(10, 22, (2,))
|
||||||
|
x_lengths[-1] = 21
|
||||||
|
x_mask = sequence_mask(x_lengths).unsqueeze(1).long()
|
||||||
|
durations = T.randint(1, 4, (2, 21))
|
||||||
|
durations = durations * x_mask.squeeze(1)
|
||||||
|
y_lengths = durations.sum(1)
|
||||||
|
y_mask = sequence_mask(y_lengths).unsqueeze(1).long()
|
||||||
|
y = T.rand(2, y_lengths.max(), 80)
|
||||||
|
pitch = T.rand(2, 1, y_lengths.max())
|
||||||
|
|
||||||
|
outputs = model.forward(x, x_lengths, y_lengths, dr=durations, pitch=pitch, y=y)
|
||||||
|
|
||||||
|
assert outputs["model_outputs"].shape == (2, durations.sum(1).max(), 80)
|
||||||
|
assert outputs["durations_log"].shape == (2, 21)
|
||||||
|
assert outputs["durations"].shape == (2, 21)
|
||||||
|
assert outputs["alignments"].shape == (2, durations.sum(1).max(), 21)
|
||||||
|
assert (outputs["x_mask"] - x_mask).sum() == 0.0
|
||||||
|
assert (outputs["y_mask"] - y_mask).sum() == 0.0
|
||||||
|
assert outputs["alignment_soft"].shape == (2, durations.sum(1).max(), 21)
|
||||||
|
assert outputs["alignment_mas"].shape == (2, durations.sum(1).max(), 21)
|
||||||
|
assert outputs["alignment_logprob"].shape == (2, 1, durations.sum(1).max(), 21)
|
||||||
|
assert outputs["o_alignment_dur"].shape == (2, 21)
|
||||||
|
assert outputs["pitch_avg"].shape == (2, 1, 21)
|
||||||
|
assert outputs["pitch_avg_gt"].shape == (2, 1, 21)
|
Loading…
Reference in New Issue