Capacitron (#977)

* new CI config

* initial Capacitron implementation

* delete old unused file

* fix empty formatting changes

* update losses and training script

* fix previous commit

* fix commit

* Add Capacitron test and first round of test fixes

* revert formatter change

* add changes to the synthesizer

* add stepwise gradual lr scheduler and changes to the recipe

* add inference script for dev use

* feat: add posterior inference arguments to synth methods
- added reference wav and text args for posterior inference
- some formatting

* fix: add espeak flag to base_tts and dataset APIs
- use_espeak_phonemes flag was not implemented in those APIs
- espeak is now able to be utilised for phoneme generation
- necessary phonemizer for the Capacitron model

* chore: update training script and style
- training script includes the espeak flag and other hyperparams
- made style

* chore: fix linting

* feat: add Tacotron 2 support

* leftover from dev

* chore: rename parser args

* feat: extract optimizers
- created a separate optimizer class to merge the two optimizers

* chore: revert arbitrary trainer changes

* fmt: revert formatting bug

* formatting again

* formatting fixed

* fix: log func

* fix: update optimizer
- Implemented load_state_dict for continuing training

* fix: clean optimizer init for standard models

* improvement: purge espeak flags and add training scripts

* Delete capacitronT2.py

delete old training script, new one is pushed

* feat: capacitron trainer methods
- extracted capacitron-specific training operations from the trainer into custom
methods in taco1 and taco2 models

* chore: renaming and merging capacitron and gst style args

* fix: bug fixes from the previous commit

* fix: implement state_dict method on CapacitronOptimizer

* fix: call method

* fix: inference naming

* Delete train_capacitron.py

* fix: synthesize

* feat: update tests

* chore: fix style

* Delete capacitron_inference.py

* fix: fix train tts t2 capacitron tests

* fix: double forward in T2 train step

* fix: double forward in T1 train step

* fix: run make style

* fix: remove unused import

* fix: test for T1 capacitron

* fix: make lint

* feat: add blizzard2013 recipes

* make style

* fix: update recipes

* chore: make style

* Plot test sentences in Tacotron

* chore: make style and fix import

* fix: call forward first before problematic floordiv op

* fix: update recipes

* feat: add min_audio_len to recipes

* aux_input["style_mel"]

* chore: make style

* Make capacitron T2 recipe more stable

* Remove T1 capacitron Ljspeech

* feat: implement new grad clipping routine and update configs

* make style

* Add pretrained checkpoints

* Add default vocoder

* Change trainer package

* Fix grad clip issue for tacotron

* Fix scheduler issue with tacotron

Co-authored-by: Eren Gölge <egolge@coqui.ai>
Co-authored-by: WeberJulian <julian.weber@hotmail.fr>
Co-authored-by: Eren Gölge <erogol@hotmail.com>
a-froghyar 2022-05-20 16:17:11 +02:00 committed by GitHub
parent ee99a6c1e2
commit 8be21ec387
20 changed files with 1194 additions and 39 deletions

View File

@ -119,6 +119,26 @@
"license": "apache 2.0",
"contact": "egolge@coqui.com"
}
},
"blizzard2013": {
"capacitron-t2-c50": {
"description": "Capacitron additions to Tacotron 2 with Capacity at 50 as in https://arxiv.org/pdf/1906.03402.pdf",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.7.0_models/tts_models--en--blizzard2013--capacitron-t2-c50.zip",
"commit": "d6284e7",
"default_vocoder": "vocoder_models/en/blizzard2013/hifigan_v2",
"author": "Adam Froghyar @a-froghyar",
"license": "apache 2.0",
"contact": "adamfroghyar@gmail.com"
},
"capacitron-t2-c150": {
"description": "Capacitron additions to Tacotron 2 with Capacity at 150 as in https://arxiv.org/pdf/1906.03402.pdf",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.7.0_models/tts_models--en--blizzard2013--capacitron-t2-c150.zip",
"commit": "d6284e7",
"default_vocoder": "vocoder_models/en/blizzard2013/hifigan_v2",
"author": "Adam Froghyar @a-froghyar",
"license": "apache 2.0",
"contact": "adamfroghyar@gmail.com"
}
}
},
"es": {
@ -379,6 +399,16 @@
"contact": "egolge@coqui.ai"
}
},
"blizzard2013": {
"hifigan_v2": {
"description": "HiFiGAN_v2 LJSpeech vocoder from https://arxiv.org/abs/2010.05646.",
"github_rls_url": "https://coqui.gateway.scarf.sh/v0.7.0_models/vocoder_models--en--blizzard2013--hifigan_v2.zip",
"commit": "d6284e7",
"author": "Adam Froghyar @a-froghyar",
"license": "apache 2.0",
"contact": "adamfroghyar@gmail.com"
}
},
"vctk": {
"hifigan_v2": {
"description": "Finetuned and intended to be used with tts_models/en/vctk/sc-glow-tts",

View File

@ -172,6 +172,10 @@ If you don't specify any models, then it uses LJSpeech based English model.
default=None,
)
parser.add_argument("--gst_style", help="Wav path file for GST style reference.", default=None)
parser.add_argument(
"--capacitron_style_wav", type=str, help="Wav path file for Capacitron prosody reference.", default=None
)
parser.add_argument("--capacitron_style_text", type=str, help="Transcription of the reference.", default=None)
parser.add_argument(
"--list_speaker_idxs",
help="List available speaker ids for the defined multi-speaker model.",
@ -308,6 +312,8 @@ If you don't specify any models, then it uses LJSpeech based English model.
args.language_idx,
args.speaker_wav,
reference_wav=args.reference_wav,
style_wav=args.capacitron_style_wav,
style_text=args.capacitron_style_text,
reference_speaker_name=args.reference_speaker_idx,
)

View File

@ -48,6 +48,50 @@ class GSTConfig(Coqpit):
check_argument("gst_num_style_tokens", c, restricted=True, min_val=1, max_val=1000)
@dataclass
class CapacitronVAEConfig(Coqpit):
"""Defines the capacitron VAE Module
Args:
capacitron_capacity (int):
Defines the variational capacity limit of the prosody embeddings. Defaults to 150.
capacitron_VAE_embedding_dim (int):
Defines the size of the Capacitron embedding vector dimension. Defaults to 128.
capacitron_use_text_summary_embeddings (bool):
If True, use a text summary embedding in Capacitron. Defaults to True.
capacitron_text_summary_embedding_dim (int):
Defines the size of the capacitron text embedding vector dimension. Defaults to 128.
capacitron_use_speaker_embedding (bool):
If True, use speaker embeddings in Capacitron. Defaults to False.
capacitron_VAE_loss_alpha (float):
Weight for the VAE loss of the Tacotron model. If set less than or equal to zero, it disables the
corresponding loss function. Defaults to 0.25.
capacitron_grad_clip (float):
Gradient clipping value for all gradients except beta. Defaults to 5.0.
"""
capacitron_loss_alpha: int = 1
capacitron_capacity: int = 150
capacitron_VAE_embedding_dim: int = 128
capacitron_use_text_summary_embeddings: bool = True
capacitron_text_summary_embedding_dim: int = 128
capacitron_use_speaker_embedding: bool = False
capacitron_VAE_loss_alpha: float = 0.25
capacitron_grad_clip: float = 5.0
def check_values(
self,
):
"""Check config fields"""
c = asdict(self)
super().check_values()
check_argument("capacitron_capacity", c, restricted=True, min_val=10, max_val=500)
check_argument("capacitron_VAE_embedding_dim", c, restricted=True, min_val=16, max_val=1024)
check_argument("capacitron_use_speaker_embedding", c, restricted=False)
check_argument("capacitron_text_summary_embedding_dim", c, restricted=False, min_val=16, max_val=512)
check_argument("capacitron_VAE_loss_alpha", c, restricted=False)
check_argument("capacitron_grad_clip", c, restricted=False)
@dataclass
class CharactersConfig(Coqpit):
"""Defines arguments for the `BaseCharacters` or `BaseVocabulary` and their subclasses.

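As a quick illustration of the fields documented above, a Capacitron-enabled config could be assembled like this (values are illustrative; the full training recipes appear later in this diff):

```python
from TTS.tts.configs.shared_configs import CapacitronVAEConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config

# Illustrative values; see the Blizzard/LJSpeech recipes further down for tuned settings.
capacitron_config = CapacitronVAEConfig(
    capacitron_capacity=150,                      # variational capacity limit on the KL term
    capacitron_VAE_embedding_dim=128,
    capacitron_use_text_summary_embeddings=True,
    capacitron_text_summary_embedding_dim=128,
    capacitron_VAE_loss_alpha=1.0,
    capacitron_grad_clip=5.0,
)
config = Tacotron2Config(use_capacitron_vae=True, capacitron_vae=capacitron_config)
```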
View File

@ -1,7 +1,7 @@
from dataclasses import dataclass, field
from typing import List
from TTS.tts.configs.shared_configs import BaseTTSConfig, GSTConfig
from TTS.tts.configs.shared_configs import BaseTTSConfig, CapacitronVAEConfig, GSTConfig
@dataclass
@ -23,6 +23,10 @@ class TacotronConfig(BaseTTSConfig):
gst_style_input (str):
Path to the wav file used at inference to set the speech style through GST. If `GST` is enabled and
this is not defined, the model uses a zero vector as an input. Defaults to None.
use_capacitron_vae (bool):
Enable/disable the use of Capacitron modules. Defaults to False.
capacitron_vae (CapacitronVAEConfig):
Instance of `CapacitronVAEConfig` class.
num_chars (int):
Number of characters used by the model. It must be defined before initializing the model. Defaults to None.
num_speakers (int):
@ -143,6 +147,9 @@ class TacotronConfig(BaseTTSConfig):
gst: GSTConfig = None
gst_style_input: str = None
use_capacitron_vae: bool = False
capacitron_vae: CapacitronVAEConfig = None
# model specific params
num_speakers: int = 1
num_chars: int = 0

View File

@ -281,6 +281,10 @@ class TacotronLoss(torch.nn.Module):
def __init__(self, c, ga_sigma=0.4):
super().__init__()
self.stopnet_pos_weight = c.stopnet_pos_weight
self.use_capacitron_vae = c.use_capacitron_vae
if self.use_capacitron_vae:
self.capacitron_capacity = c.capacitron_vae.capacitron_capacity
self.capacitron_vae_loss_alpha = c.capacitron_vae.capacitron_VAE_loss_alpha
self.ga_alpha = c.ga_alpha
self.decoder_diff_spec_alpha = c.decoder_diff_spec_alpha
self.postnet_diff_spec_alpha = c.postnet_diff_spec_alpha
@ -308,6 +312,9 @@ class TacotronLoss(torch.nn.Module):
# pylint: disable=not-callable
self.criterion_st = BCELossMasked(pos_weight=torch.tensor(self.stopnet_pos_weight)) if c.stopnet else None
# For dev purposes only
self.criterion_capacitron_reconstruction_loss = nn.L1Loss(reduction="sum")
def forward(
self,
postnet_output,
@ -317,6 +324,7 @@ class TacotronLoss(torch.nn.Module):
stopnet_output,
stopnet_target,
stop_target_length,
capacitron_vae_outputs,
output_lens,
decoder_b_output,
alignments,
@ -348,6 +356,55 @@ class TacotronLoss(torch.nn.Module):
return_dict["decoder_loss"] = decoder_loss
return_dict["postnet_loss"] = postnet_loss
if self.use_capacitron_vae:
# extract capacitron vae infos
posterior_distribution, prior_distribution, beta = capacitron_vae_outputs
# KL divergence term between the posterior and the prior
kl_term = torch.mean(torch.distributions.kl_divergence(posterior_distribution, prior_distribution))
# Limit the mutual information between the data and latent space by the variational capacity limit
kl_capacity = kl_term - self.capacitron_capacity
# pass beta through softplus to keep it positive
beta = torch.nn.functional.softplus(beta)[0]
# This is the term going to the main ADAM optimiser; we detach beta because
# beta is optimised by a separate SGD optimiser below
capacitron_vae_loss = beta.detach() * kl_capacity
# normalize the capacitron_vae_loss as in L1Loss or MSELoss.
# After this, both the standard loss and capacitron_vae_loss will be in the same scale.
# For this reason we don't need to use L1Loss and MSELoss in "sum" reduction mode.
# Note: the batch is not considered because the L1Loss was calculated in "sum" mode
# divided by the batch size, so not dividing the capacitron_vae_loss by B is legitimate.
# get B T D dimension from input
B, T, D = mel_input.size()
# normalize
if self.config.loss_masking:
# if mask loss get T using the mask
T = output_lens.sum() / B
# Only for dev purposes to be able to compare the reconstruction loss with the values in the
# original Capacitron paper
return_dict["capaciton_reconstruction_loss"] = (
self.criterion_capacitron_reconstruction_loss(decoder_output, mel_input) / decoder_output.size(0)
) + kl_capacity
capacitron_vae_loss = capacitron_vae_loss / (T * D)
capacitron_vae_loss = capacitron_vae_loss * self.capacitron_vae_loss_alpha
# This is the term to purely optimise beta and to pass into the SGD optimizer
beta_loss = torch.negative(beta) * kl_capacity.detach()
loss += capacitron_vae_loss
return_dict["capacitron_vae_loss"] = capacitron_vae_loss
return_dict["capacitron_vae_beta_loss"] = beta_loss
return_dict["capacitron_vae_kl_term"] = kl_term
return_dict["capacitron_beta"] = beta
stop_loss = (
self.criterion_st(stopnet_output, stopnet_target, stop_target_length)
if self.config.stopnet

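To make the loss wiring above easier to follow, here is a minimal, self-contained sketch (illustrative shapes and values, not the project's API) of the constrained-KL objective: the main optimizer minimises `beta.detach() * (KL - C)`, while a separate SGD step updates beta through `-beta * (KL - C).detach()`.

```python
import torch

embedding_dim, capacity = 128, 150  # capacity plays the role of capacitron_capacity (C)
mu = torch.zeros(4, embedding_dim)   # stand-in posterior parameters for a batch of 4
sigma = torch.ones(4, embedding_dim)
posterior = torch.distributions.MultivariateNormal(mu, torch.diag_embed(sigma))
prior = torch.distributions.MultivariateNormal(torch.zeros(embedding_dim), torch.eye(embedding_dim))

unconstrained_beta = torch.nn.Parameter(torch.log(torch.exp(torch.tensor([1.0])) - 1.0))
kl_term = torch.mean(torch.distributions.kl_divergence(posterior, prior))
kl_capacity = kl_term - capacity                             # negative until the KL exceeds the capacity limit
beta = torch.nn.functional.softplus(unconstrained_beta)[0]   # keep the Lagrange-like multiplier positive

capacitron_vae_loss = beta.detach() * kl_capacity            # term added to the main (RAdam) loss
beta_loss = torch.negative(beta) * kl_capacity.detach()      # term for the secondary SGD optimizer
```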
View File

@ -0,0 +1,205 @@
import torch
from torch import nn
from torch.distributions.multivariate_normal import MultivariateNormal as MVN
from torch.nn import functional as F
class CapacitronVAE(nn.Module):
"""Effective Use of Variational Embedding Capacity for prosody transfer.
See https://arxiv.org/abs/1906.03402"""
def __init__(
self,
num_mel,
capacitron_VAE_embedding_dim,
encoder_output_dim=256,
reference_encoder_out_dim=128,
speaker_embedding_dim=None,
text_summary_embedding_dim=None,
):
super().__init__()
# Init distributions
self.prior_distribution = MVN(
torch.zeros(capacitron_VAE_embedding_dim), torch.eye(capacitron_VAE_embedding_dim)
)
self.approximate_posterior_distribution = None
# define output ReferenceEncoder dim to the capacitron_VAE_embedding_dim
self.encoder = ReferenceEncoder(num_mel, out_dim=reference_encoder_out_dim)
# Init beta, the lagrange-like term for the KL distribution
self.beta = torch.nn.Parameter(torch.log(torch.exp(torch.Tensor([1.0])) - 1), requires_grad=True)
mlp_input_dimension = reference_encoder_out_dim
if text_summary_embedding_dim is not None:
self.text_summary_net = TextSummary(text_summary_embedding_dim, encoder_output_dim=encoder_output_dim)
mlp_input_dimension += text_summary_embedding_dim
if speaker_embedding_dim is not None:
# TODO: Test a multispeaker model!
mlp_input_dimension += speaker_embedding_dim
self.post_encoder_mlp = PostEncoderMLP(mlp_input_dimension, capacitron_VAE_embedding_dim)
def forward(self, reference_mel_info=None, text_info=None, speaker_embedding=None):
# Use reference
if reference_mel_info is not None:
reference_mels = reference_mel_info[0] # [batch_size, num_frames, num_mels]
mel_lengths = reference_mel_info[1] # [batch_size]
enc_out = self.encoder(reference_mels, mel_lengths)
# concat speaker_embedding and/or text summary embedding
if text_info is not None:
text_inputs = text_info[0] # [batch_size, num_characters, num_embedding]
input_lengths = text_info[1]
text_summary_out = self.text_summary_net(text_inputs, input_lengths).to(reference_mels.device)
enc_out = torch.cat([enc_out, text_summary_out], dim=-1)
if speaker_embedding is not None:
enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)
# Feed the output of the ref encoder and information about text/speaker into
# an MLP to produce the parameters of the approximate posterior distribution
mu, sigma = self.post_encoder_mlp(enc_out)
# convert to cpu because prior_distribution was created on cpu
mu = mu.cpu()
sigma = sigma.cpu()
# Sample from the posterior: z ~ q(z|x)
self.approximate_posterior_distribution = MVN(mu, torch.diag_embed(sigma))
VAE_embedding = self.approximate_posterior_distribution.rsample()
# Infer from the model, bypasses encoding
else:
# Sample from the prior: z ~ p(z)
VAE_embedding = self.prior_distribution.sample().unsqueeze(0)
# reshape to [batch_size, 1, capacitron_VAE_embedding_dim]
return VAE_embedding.unsqueeze(1), self.approximate_posterior_distribution, self.prior_distribution, self.beta
class ReferenceEncoder(nn.Module):
"""NN module creating a fixed size prosody embedding from a spectrogram.
inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
outputs: [batch_size, embedding_dim]
"""
def __init__(self, num_mel, out_dim):
super().__init__()
self.num_mel = num_mel
filters = [1] + [32, 32, 64, 64, 128, 128]
num_layers = len(filters) - 1
convs = [
nn.Conv2d(
in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(2, 2)
)
for i in range(num_layers)
]
self.convs = nn.ModuleList(convs)
self.training = False
self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]])
post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 2, num_layers)
self.recurrence = nn.LSTM(
input_size=filters[-1] * post_conv_height, hidden_size=out_dim, batch_first=True, bidirectional=False
)
def forward(self, inputs, input_lengths):
batch_size = inputs.size(0)
x = inputs.view(batch_size, 1, -1, self.num_mel) # [batch_size, num_channels==1, num_frames, num_mel]
valid_lengths = input_lengths.float() # [batch_size]
for conv, bn in zip(self.convs, self.bns):
x = conv(x)
x = bn(x)
x = F.relu(x)
# Create the post conv width mask based on the valid lengths of the output of the convolution.
# The valid length for the output of a convolution on a varying-length input is
# ceil(input_length/stride) + 1 for stride=2 and padding=2
# For example (kernel_size=3, stride=2, padding=2):
# 0 0 x x x x x 0 0 -> Input = 5, 0 is zero padding, x is valid values coming from padding=2 in conv2d
# _____
# x _____
# x _____
# x ____
# x
# x x x x -> Output valid length = 4
# Since every example in the batch is zero padded and therefore has its own valid length,
# we need to mask off all the values AFTER the valid length for each example in the batch.
# Otherwise, the convolutions create noise and a lot of spurious information.
valid_lengths = (valid_lengths / 2).float()
valid_lengths = torch.ceil(valid_lengths).to(dtype=torch.int64) + 1 # 2 is stride -- size: [batch_size]
post_conv_max_width = x.size(2)
mask = torch.arange(post_conv_max_width).to(inputs.device).expand(
len(valid_lengths), post_conv_max_width
) < valid_lengths.unsqueeze(1)
mask = mask.expand(1, 1, -1, -1).transpose(2, 0).transpose(-1, 2) # [batch_size, 1, post_conv_max_width, 1]
x = x * mask
x = x.transpose(1, 2)
# x: 4D tensor [batch_size, post_conv_width,
# num_channels==128, post_conv_height]
post_conv_width = x.size(1)
x = x.contiguous().view(batch_size, post_conv_width, -1)
# x: 3D tensor [batch_size, post_conv_width,
# num_channels*post_conv_height]
# Routine for fetching the last valid output of a dynamic LSTM with varying input lengths and padding
post_conv_input_lengths = valid_lengths
packed_seqs = nn.utils.rnn.pack_padded_sequence(
x, post_conv_input_lengths.tolist(), batch_first=True, enforce_sorted=False
) # dynamic rnn sequence padding
self.recurrence.flatten_parameters()
_, (ht, _) = self.recurrence(packed_seqs)
last_output = ht[-1]
return last_output.to(inputs.device) # [B, 128]
@staticmethod
def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs):
"""Height of spec after n convolutions with fixed kernel/stride/pad."""
for _ in range(n_convs):
height = (height - kernel_size + 2 * pad) // stride + 1
return height
class TextSummary(nn.Module):
def __init__(self, embedding_dim, encoder_output_dim):
super().__init__()
self.lstm = nn.LSTM(
encoder_output_dim, # text embedding dimension from the text encoder
embedding_dim, # fixed length output summary the lstm creates from the input
batch_first=True,
bidirectional=False,
)
def forward(self, inputs, input_lengths):
# Routine for fetching the last valid output of a dynamic LSTM with varying input lengths and padding
packed_seqs = nn.utils.rnn.pack_padded_sequence(
inputs, input_lengths.tolist(), batch_first=True, enforce_sorted=False
) # dynamic rnn sequence padding
self.lstm.flatten_parameters()
_, (ht, _) = self.lstm(packed_seqs)
last_output = ht[-1]
return last_output
class PostEncoderMLP(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.hidden_size = hidden_size
modules = [
nn.Linear(input_size, hidden_size), # Hidden Layer
nn.Tanh(),
nn.Linear(hidden_size, hidden_size * 2),
] # Output layer twice the size for mean and variance
self.net = nn.Sequential(*modules)
self.softplus = nn.Softplus()
def forward(self, _input):
mlp_output = self.net(_input)
# The mean parameter is unconstrained
mu = mlp_output[:, : self.hidden_size]
# The standard deviation must be positive. Parameterise with a softplus
sigma = self.softplus(mlp_output[:, self.hidden_size :])
return mu, sigma

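A quick shape check for the module above (batch sizes and lengths are illustrative; `encoder_output_dim=512` matches the Tacotron 2 wiring later in this diff):

```python
import torch
from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE

vae = CapacitronVAE(
    num_mel=80,                         # decoder_output_dim in the models below
    capacitron_VAE_embedding_dim=128,
    encoder_output_dim=512,             # Tacotron 2 encoder output size
    text_summary_embedding_dim=128,
)
mels = torch.rand(2, 120, 80)           # [batch_size, num_frames, num_mel]
mel_lengths = torch.tensor([120, 90])
text = torch.rand(2, 50, 512)           # [batch_size, num_chars, encoder_output_dim]
text_lengths = torch.tensor([50, 33])

emb, posterior, prior, beta = vae(reference_mel_info=[mels, mel_lengths], text_info=[text, text_lengths])
print(emb.shape)                        # torch.Size([2, 1, 128]), concatenated onto the encoder outputs
```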
View File

@ -139,7 +139,7 @@ class MultiHeadAttention(nn.Module):
keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
# score = softmax(QK^T / (d_k ** 0.5))
# score = softmax(QK^T / (d_k**0.5))
scores = torch.matmul(queries, keys.transpose(2, 3)) # [h, N, T_q, T_k]
scores = scores / (self.key_dim**0.5)
scores = F.softmax(scores, dim=3)

View File

@ -1,6 +1,6 @@
import copy
from abc import abstractmethod
from typing import Dict
from typing import Dict, Tuple
import torch
from coqpit import Coqpit
@ -10,7 +10,9 @@ from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.generic_utils import format_aux_input
from TTS.utils.io import load_fsspec
from TTS.utils.training import gradual_training_scheduler
@ -47,6 +49,11 @@ class BaseTacotron(BaseTTS):
self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim
self.gst_layer = None
# Capacitron
if self.capacitron_vae and self.use_capacitron_vae:
self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim # add capacitron embedding dim
self.capacitron_vae_layer = None
# additional layers
self.decoder_backward = None
self.coarse_decoder = None
@ -125,6 +132,53 @@ class BaseTacotron(BaseTTS):
speaker_manager = SpeakerManager.init_from_config(config)
return BaseTacotron(config, ap, tokenizer, speaker_manager)
##########################
# TEST AND LOG FUNCTIONS #
##########################
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
"""Generic test run for `tts` models used by `Trainer`.
You can override this for a different behaviour.
Args:
assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`.
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
test_sentences = self.config.test_sentences
aux_inputs = self._get_test_aux_input()
for idx, sen in enumerate(test_sentences):
outputs_dict = synthesis(
self,
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],
use_griffin_lim=True,
do_trim_silence=False,
)
test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
)
test_figures["{}-alignment".format(idx)] = plot_alignment(
outputs_dict["outputs"]["alignments"], output_fig=False
)
return {"figures": test_figures, "audios": test_audios}
def test_log(
self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
) -> None:
logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
logger.test_figures(steps, outputs["figures"])
#############################
# COMMON COMPUTE FUNCTIONS
#############################
@ -160,7 +214,9 @@ class BaseTacotron(BaseTTS):
)
# scale_factor = self.decoder.r_init / self.decoder.r
alignments_backward = torch.nn.functional.interpolate(
alignments_backward.transpose(1, 2), size=alignments.shape[1], mode="nearest"
alignments_backward.transpose(1, 2),
size=alignments.shape[1],
mode="nearest",
).transpose(1, 2)
decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2)
decoder_outputs_backward = decoder_outputs_backward[:, :T, :]
@ -193,6 +249,25 @@ class BaseTacotron(BaseTTS):
inputs = self._concat_speaker_embedding(inputs, gst_outputs)
return inputs
def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None):
"""Capacitron Variational Autoencoder"""
(VAE_outputs, posterior_distribution, prior_distribution, capacitron_beta,) = self.capacitron_vae_layer(
reference_mel_info,
text_info,
speaker_embedding, # pylint: disable=not-callable
)
VAE_outputs = VAE_outputs.to(inputs.device)
encoder_output = self._concat_speaker_embedding(
inputs, VAE_outputs
) # concatenate to the output of the basic tacotron encoder
return (
encoder_output,
posterior_distribution,
prior_distribution,
capacitron_beta,
)
@staticmethod
def _add_speaker_embedding(outputs, embedded_speakers):
embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)

View File

@ -1,11 +1,13 @@
# coding: utf-8
from typing import Dict, List, Union
from typing import Dict, List, Tuple, Union
import torch
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE
from TTS.tts.layers.tacotron.gst_layers import GST
from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
from TTS.tts.models.base_tacotron import BaseTacotron
@ -13,6 +15,7 @@ from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.capacitron_optimizer import CapacitronOptimizer
class Tacotron(BaseTacotron):
@ -51,6 +54,9 @@ class Tacotron(BaseTacotron):
if self.use_gst:
self.decoder_in_features += self.gst.gst_embedding_dim
if self.use_capacitron_vae:
self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim
# embedding layer
self.embedding = nn.Embedding(self.num_chars, 256, padding_idx=0)
self.embedding.weight.data.normal_(0, 0.3)
@ -90,6 +96,20 @@ class Tacotron(BaseTacotron):
gst_embedding_dim=self.gst.gst_embedding_dim,
)
# Capacitron layers
if self.capacitron_vae and self.use_capacitron_vae:
self.capacitron_vae_layer = CapacitronVAE(
num_mel=self.decoder_output_dim,
encoder_output_dim=self.encoder_in_features,
capacitron_VAE_embedding_dim=self.capacitron_vae.capacitron_VAE_embedding_dim,
speaker_embedding_dim=self.embedded_speaker_dim
if self.use_speaker_embedding and self.capacitron_vae.capacitron_use_speaker_embedding
else None,
text_summary_embedding_dim=self.capacitron_vae.capacitron_text_summary_embedding_dim
if self.capacitron_vae.capacitron_use_text_summary_embeddings
else None,
)
# backward pass decoder
if self.bidirectional_decoder:
self._init_backward_decoder()
@ -146,6 +166,19 @@ class Tacotron(BaseTacotron):
# B x 1 x speaker_embed_dim
embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1)
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers)
# Capacitron
if self.capacitron_vae and self.use_capacitron_vae:
# B x capacitron_VAE_embedding_dim
encoder_outputs, *capacitron_vae_outputs = self.compute_capacitron_VAE_embedding(
encoder_outputs,
reference_mel_info=[mel_specs, mel_lengths],
text_info=[inputs, text_lengths]
if self.capacitron_vae.capacitron_use_text_summary_embeddings
else None,
speaker_embedding=embedded_speakers if self.capacitron_vae.capacitron_use_speaker_embedding else None,
)
else:
capacitron_vae_outputs = None
# decoder_outputs: B x decoder_in_features x T_out
# alignments: B x T_in x encoder_in_features
# stop_tokens: B x T_in
@ -178,6 +211,7 @@ class Tacotron(BaseTacotron):
"decoder_outputs": decoder_outputs,
"alignments": alignments,
"stop_tokens": stop_tokens,
"capacitron_vae_outputs": capacitron_vae_outputs,
}
)
return outputs
@ -190,6 +224,28 @@ class Tacotron(BaseTacotron):
if self.gst and self.use_gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
if self.capacitron_vae and self.use_capacitron_vae:
if aux_input["style_text"] is not None:
style_text_embedding = self.embedding(aux_input["style_text"])
style_text_length = torch.tensor([style_text_embedding.size(1)], dtype=torch.int64).to(
encoder_outputs.device
) # pylint: disable=not-callable
reference_mel_length = (
torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
if aux_input["style_mel"] is not None
else None
) # pylint: disable=not-callable
# B x capacitron_VAE_embedding_dim
encoder_outputs, *_ = self.compute_capacitron_VAE_embedding(
encoder_outputs,
reference_mel_info=[aux_input["style_mel"], reference_mel_length]
if aux_input["style_mel"] is not None
else None,
text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None,
speaker_embedding=aux_input["d_vectors"]
if self.capacitron_vae.capacitron_use_speaker_embedding
else None,
)
if self.num_speakers > 1:
if not self.use_d_vector_file:
# B x 1 x speaker_embed_dim
@ -215,12 +271,19 @@ class Tacotron(BaseTacotron):
}
return outputs
def train_step(self, batch, criterion):
"""Perform a single training step by fetching the right set if samples from the batch.
def before_backward_pass(self, loss_dict, optimizer) -> None:
# Extracting custom training specific operations for capacitron
# from the trainer
if self.use_capacitron_vae:
loss_dict["capacitron_vae_beta_loss"].backward()
optimizer.first_step()
def train_step(self, batch: Dict, criterion: torch.nn.Module) -> Tuple[Dict, Dict]:
"""Perform a single training step by fetching the right set of samples from the batch.
Args:
batch ([type]): [description]
criterion ([type]): [description]
batch ([Dict]): A dictionary of input tensors.
criterion ([torch.nn.Module]): Callable criterion to compute model loss.
"""
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
@ -232,14 +295,8 @@ class Tacotron(BaseTacotron):
speaker_ids = batch["speaker_ids"]
d_vectors = batch["d_vectors"]
# forward pass model
outputs = self.forward(
text_input,
text_lengths,
mel_input,
mel_lengths,
aux_input={"speaker_ids": speaker_ids, "d_vectors": d_vectors},
)
aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)
# set the [alignment] lengths wrt reduction factor for guided attention
if mel_lengths.max() % self.decoder.r != 0:
@ -249,9 +306,6 @@ class Tacotron(BaseTacotron):
else:
alignment_lengths = mel_lengths // self.decoder.r
aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)
# compute loss
with autocast(enabled=False): # use float32 for the criterion
loss_dict = criterion(
@ -262,6 +316,7 @@ class Tacotron(BaseTacotron):
outputs["stop_tokens"].float(),
stop_targets.float(),
stop_target_lengths,
outputs["capacitron_vae_outputs"] if self.capacitron_vae else None,
mel_lengths,
None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(),
outputs["alignments"].float(),
@ -275,6 +330,25 @@ class Tacotron(BaseTacotron):
loss_dict["align_error"] = align_error
return outputs, loss_dict
def get_optimizer(self) -> List:
if self.use_capacitron_vae:
return CapacitronOptimizer(self.config, self.named_parameters())
return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self)
def get_scheduler(self, optimizer: object):
opt = optimizer.primary_optimizer if self.use_capacitron_vae else optimizer
return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, opt)
def before_gradient_clipping(self):
if self.use_capacitron_vae:
# Capacitron model specific gradient clipping
model_params_to_clip = []
for name, param in self.named_parameters():
if param.requires_grad:
if name != "capacitron_vae_layer.beta":
model_params_to_clip.append(param)
torch.nn.utils.clip_grad_norm_(model_params_to_clip, self.capacitron_vae.capacitron_grad_clip)
def _create_logs(self, batch, outputs, ap):
postnet_outputs = outputs["model_outputs"]
decoder_outputs = outputs["decoder_outputs"]

View File

@ -5,7 +5,9 @@ from typing import Dict, List, Union
import torch
from torch import nn
from torch.cuda.amp.autocast_mode import autocast
from trainer.trainer_utils import get_optimizer, get_scheduler
from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE
from TTS.tts.layers.tacotron.gst_layers import GST
from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
from TTS.tts.models.base_tacotron import BaseTacotron
@ -13,6 +15,7 @@ from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.capacitron_optimizer import CapacitronOptimizer
class Tacotron2(BaseTacotron):
@ -65,6 +68,9 @@ class Tacotron2(BaseTacotron):
if self.use_gst:
self.decoder_in_features += self.gst.gst_embedding_dim
if self.use_capacitron_vae:
self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim
# embedding layer
self.embedding = nn.Embedding(self.num_chars, 512, padding_idx=0)
@ -102,6 +108,20 @@ class Tacotron2(BaseTacotron):
gst_embedding_dim=self.gst.gst_embedding_dim,
)
# Capacitron VAE Layers
if self.capacitron_vae and self.use_capacitron_vae:
self.capacitron_vae_layer = CapacitronVAE(
num_mel=self.decoder_output_dim,
encoder_output_dim=self.encoder_in_features,
capacitron_VAE_embedding_dim=self.capacitron_vae.capacitron_VAE_embedding_dim,
speaker_embedding_dim=self.embedded_speaker_dim
if self.capacitron_vae.capacitron_use_speaker_embedding
else None,
text_summary_embedding_dim=self.capacitron_vae.capacitron_text_summary_embedding_dim
if self.capacitron_vae.capacitron_use_text_summary_embeddings
else None,
)
# backward pass decoder
if self.bidirectional_decoder:
self._init_backward_decoder()
@ -166,6 +186,20 @@ class Tacotron2(BaseTacotron):
embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1)
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers)
# capacitron
if self.capacitron_vae and self.use_capacitron_vae:
# B x capacitron_VAE_embedding_dim
encoder_outputs, *capacitron_vae_outputs = self.compute_capacitron_VAE_embedding(
encoder_outputs,
reference_mel_info=[mel_specs, mel_lengths],
text_info=[embedded_inputs.transpose(1, 2), text_lengths]
if self.capacitron_vae.capacitron_use_text_summary_embeddings
else None,
speaker_embedding=embedded_speakers if self.capacitron_vae.capacitron_use_speaker_embedding else None,
)
else:
capacitron_vae_outputs = None
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
# B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
@ -197,6 +231,7 @@ class Tacotron2(BaseTacotron):
"decoder_outputs": decoder_outputs,
"alignments": alignments,
"stop_tokens": stop_tokens,
"capacitron_vae_outputs": capacitron_vae_outputs,
}
)
return outputs
@ -217,6 +252,29 @@ class Tacotron2(BaseTacotron):
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
if self.capacitron_vae and self.use_capacitron_vae:
if aux_input["style_text"] is not None:
style_text_embedding = self.embedding(aux_input["style_text"])
style_text_length = torch.tensor([style_text_embedding.size(1)], dtype=torch.int64).to(
encoder_outputs.device
) # pylint: disable=not-callable
reference_mel_length = (
torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
if aux_input["style_mel"] is not None
else None
) # pylint: disable=not-callable
# B x capacitron_VAE_embedding_dim
encoder_outputs, *_ = self.compute_capacitron_VAE_embedding(
encoder_outputs,
reference_mel_info=[aux_input["style_mel"], reference_mel_length]
if aux_input["style_mel"] is not None
else None,
text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None,
speaker_embedding=aux_input["d_vectors"]
if self.capacitron_vae.capacitron_use_speaker_embedding
else None,
)
if self.num_speakers > 1:
if not self.use_d_vector_file:
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[None]
@ -242,6 +300,13 @@ class Tacotron2(BaseTacotron):
}
return outputs
def before_backward_pass(self, loss_dict, optimizer) -> None:
# Extracting custom training specific operations for capacitron
# from the trainer
if self.use_capacitron_vae:
loss_dict["capacitron_vae_beta_loss"].backward()
optimizer.first_step()
def train_step(self, batch: Dict, criterion: torch.nn.Module):
"""A single training step. Forward pass and loss computation.
@ -258,14 +323,8 @@ class Tacotron2(BaseTacotron):
speaker_ids = batch["speaker_ids"]
d_vectors = batch["d_vectors"]
# forward pass model
outputs = self.forward(
text_input,
text_lengths,
mel_input,
mel_lengths,
aux_input={"speaker_ids": speaker_ids, "d_vectors": d_vectors},
)
aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)
# set the [alignment] lengths wrt reduction factor for guided attention
if mel_lengths.max() % self.decoder.r != 0:
@ -275,9 +334,6 @@ class Tacotron2(BaseTacotron):
else:
alignment_lengths = mel_lengths // self.decoder.r
aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)
# compute loss
with autocast(enabled=False): # use float32 for the criterion
loss_dict = criterion(
@ -288,6 +344,7 @@ class Tacotron2(BaseTacotron):
outputs["stop_tokens"].float(),
stop_targets.float(),
stop_target_lengths,
outputs["capacitron_vae_outputs"] if self.capacitron_vae else None,
mel_lengths,
None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(),
outputs["alignments"].float(),
@ -301,6 +358,25 @@ class Tacotron2(BaseTacotron):
loss_dict["align_error"] = align_error
return outputs, loss_dict
def get_optimizer(self) -> List:
if self.use_capacitron_vae:
return CapacitronOptimizer(self.config, self.named_parameters())
return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self)
def get_scheduler(self, optimizer: object):
opt = optimizer.primary_optimizer if self.use_capacitron_vae else optimizer
return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, opt)
def before_gradient_clipping(self):
if self.use_capacitron_vae:
# Capacitron model specific gradient clipping
model_params_to_clip = []
for name, param in self.named_parameters():
if param.requires_grad:
if name != "capacitron_vae_layer.beta":
model_params_to_clip.append(param)
torch.nn.utils.clip_grad_norm_(model_params_to_clip, self.capacitron_vae.capacitron_grad_clip)
def _create_logs(self, batch, outputs, ap):
"""Create dashboard log information."""
postnet_outputs = outputs["model_outputs"]

View File

@ -26,6 +26,7 @@ def run_model_torch(
inputs: torch.Tensor,
speaker_id: int = None,
style_mel: torch.Tensor = None,
style_text: str = None,
d_vector: torch.Tensor = None,
language_id: torch.Tensor = None,
) -> Dict:
@ -53,6 +54,7 @@ def run_model_torch(
"speaker_ids": speaker_id,
"d_vectors": d_vector,
"style_mel": style_mel,
"style_text": style_text,
"language_ids": language_id,
},
)
@ -115,6 +117,7 @@ def synthesis(
use_cuda,
speaker_id=None,
style_wav=None,
style_text=None,
use_griffin_lim=False,
do_trim_silence=False,
d_vector=None,
@ -140,7 +143,12 @@ def synthesis(
Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None.
style_wav (str | Dict[str, float]):
Path or tensor to/of a waveform used for computing the style embedding. Defaults to None.
Path or tensor to/of a waveform used for computing the style embedding based on GST or Capacitron.
Defaults to None, meaning that Capacitron models will sample from the prior distribution to
generate random but realistic prosody.
style_text (str):
Transcription of style_wav for Capacitron models. Defaults to None.
enable_eos_bos_chars (bool):
enable special chars for end of sentence and start of sentence. Defaults to False.
@ -154,13 +162,19 @@ def synthesis(
language_id (int):
Language ID passed to the language embedding layer in multi-langual model. Defaults to None.
"""
# GST processing
# GST or Capacitron processing
# TODO: need to handle the case of setting both gst and capacitron to true somewhere
style_mel = None
if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
if isinstance(style_wav, dict):
style_mel = style_wav
else:
style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
if CONFIG.has("capacitron_vae") and CONFIG.use_capacitron_vae and style_wav is not None:
style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
style_mel = style_mel.transpose(1, 2) # [1, time, depth]
# convert text to sequence of token IDs
text_inputs = np.asarray(
model.tokenizer.text_to_ids(text, language=language_id),
@ -177,11 +191,28 @@ def synthesis(
language_id = id_to_torch(language_id, cuda=use_cuda)
if not isinstance(style_mel, dict):
# GST or Capacitron style mel
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
if style_text is not None:
style_text = np.asarray(
model.tokenizer.text_to_ids(style_text, language=language_id),
dtype=np.int32,
)
style_text = numpy_to_torch(style_text, torch.long, cuda=use_cuda)
style_text = style_text.unsqueeze(0)
text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
text_inputs = text_inputs.unsqueeze(0)
# synthesize voice
outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id)
outputs = run_model_torch(
model,
text_inputs,
speaker_id,
style_mel,
style_text,
d_vector=d_vector,
language_id=language_id,
)
model_outputs = outputs["model_outputs"]
model_outputs = model_outputs[0].data.cpu().numpy()
alignments = outputs["alignments"]

View File

@ -0,0 +1,65 @@
from typing import Generator
from trainer.trainer_utils import get_optimizer
class CapacitronOptimizer:
"""Double optimizer class for the Capacitron model."""
def __init__(self, config: dict, model_params: Generator) -> None:
self.primary_params, self.secondary_params = self.split_model_parameters(model_params)
optimizer_names = list(config.optimizer_params.keys())
optimizer_parameters = list(config.optimizer_params.values())
self.primary_optimizer = get_optimizer(
optimizer_names[0],
optimizer_parameters[0],
config.lr,
parameters=self.primary_params,
)
self.secondary_optimizer = get_optimizer(
optimizer_names[1],
self.extract_optimizer_parameters(optimizer_parameters[1]),
optimizer_parameters[1]["lr"],
parameters=self.secondary_params,
)
self.param_groups = self.primary_optimizer.param_groups
def first_step(self):
self.secondary_optimizer.step()
self.secondary_optimizer.zero_grad()
self.primary_optimizer.zero_grad()
def step(self):
self.primary_optimizer.step()
def zero_grad(self):
self.primary_optimizer.zero_grad()
self.secondary_optimizer.zero_grad()
def load_state_dict(self, state_dict):
self.primary_optimizer.load_state_dict(state_dict[0])
self.secondary_optimizer.load_state_dict(state_dict[1])
def state_dict(self):
return [self.primary_optimizer.state_dict(), self.secondary_optimizer.state_dict()]
@staticmethod
def split_model_parameters(model_params: Generator) -> list:
primary_params = []
secondary_params = []
for name, param in model_params:
if param.requires_grad:
if name == "capacitron_vae_layer.beta":
secondary_params.append(param)
else:
primary_params.append(param)
return [iter(primary_params), iter(secondary_params)]
@staticmethod
def extract_optimizer_parameters(params: dict) -> dict:
"""Extract parameters that are not the learning rate"""
return {k: v for k, v in params.items() if k != "lr"}

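A self-contained toy sketch (hypothetical parameters, not the project's models) of the two-phase update this class exists for: `first_step()` advances the secondary SGD optimizer on beta and clears gradients, then `step()` advances the primary optimizer on everything else, in the same order used by the Capacitron train tests later in this diff.

```python
import torch

beta = torch.nn.Parameter(torch.tensor(0.5))             # stands in for capacitron_vae_layer.beta
weights = torch.nn.Parameter(torch.randn(4))              # stands in for all remaining parameters
secondary = torch.optim.SGD([beta], lr=1e-5, momentum=0.9)
primary = torch.optim.RAdam([weights], lr=1e-3)

kl_capacity = weights.pow(2).sum() - 1.0                   # stand-in for KL(q||p) - capacity
beta_loss = -beta * kl_capacity.detach()                   # reaches beta only
main_loss = torch.nn.functional.softplus(beta).detach() * kl_capacity  # reaches the model only

primary.zero_grad(); secondary.zero_grad()                            # CapacitronOptimizer.zero_grad()
beta_loss.backward()
secondary.step(); secondary.zero_grad(); primary.zero_grad()          # CapacitronOptimizer.first_step()
main_loss.backward()
primary.step()                                                        # CapacitronOptimizer.step()
```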
View File

@ -106,6 +106,8 @@ def save_model(config, model, optimizer, scaler, current_step, epoch, output_pat
model_state = model.state_dict()
if isinstance(optimizer, list):
optimizer_state = [optim.state_dict() for optim in optimizer]
elif optimizer.__class__.__name__ == "CapacitronOptimizer":
optimizer_state = [optimizer.primary_optimizer.state_dict(), optimizer.secondary_optimizer.state_dict()]
else:
optimizer_state = optimizer.state_dict() if optimizer is not None else None

View File

@ -1,5 +1,5 @@
import time
from typing import List, Union
from typing import List
import numpy as np
import pysbd
@ -178,8 +178,9 @@ class Synthesizer(object):
text: str = "",
speaker_name: str = "",
language_name: str = "",
speaker_wav: Union[str, List[str]] = None,
speaker_wav=None,
style_wav=None,
style_text=None,
reference_wav=None,
reference_speaker_name=None,
) -> List[int]:
@ -191,6 +192,7 @@ class Synthesizer(object):
language_name (str, optional): language id for multi-language models. Defaults to "".
speaker_wav (Union[str, List[str]], optional): path to the speaker wav. Defaults to None.
style_wav ([type], optional): style waveform for GST. Defaults to None.
style_text ([type], optional): transcription of style_wav for Capacitron. Defaults to None.
reference_wav ([type], optional): reference waveform for voice conversion. Defaults to None.
reference_speaker_name ([type], optional): speaker id of the reference waveform. Defaults to None.
Returns:
@ -273,10 +275,11 @@ class Synthesizer(object):
CONFIG=self.tts_config,
use_cuda=self.use_cuda,
speaker_id=speaker_id,
language_id=language_id,
style_wav=style_wav,
style_text=style_text,
use_griffin_lim=use_gl,
d_vector=speaker_embedding,
language_id=language_id,
)
waveform = outputs["wav"]
mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()

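With the changes above, posterior vs. prior inference through the Synthesizer might look like this (checkpoint and wav paths are placeholders):

```python
from TTS.utils.synthesizer import Synthesizer

# Placeholder paths; any Capacitron checkpoint/config pair can be used here.
synthesizer = Synthesizer(
    tts_checkpoint="path/to/capacitron_model.pth",
    tts_config_path="path/to/config.json",
    vocoder_checkpoint="path/to/vocoder.pth",
    vocoder_config="path/to/vocoder_config.json",
    use_cuda=False,
)

# Posterior inference: transfer prosody from a reference wav and its transcription.
wav = synthesizer.tts(
    text="Hello there, how are you today?",
    style_wav="path/to/prosody_reference.wav",
    style_text="Transcription of the prosody reference.",
)

# Prior inference: without style_wav, Capacitron samples prosody from the prior.
wav = synthesizer.tts(text="Hello there, how are you today?")
```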
View File

@ -0,0 +1,12 @@
# How to get the Blizzard 2013 Dataset
The Capacitron model is a variational-encoder extension of standard Tacotron-based models for modelling prosody.
To take full advantage of the model, it is advised to train it on a dataset whose utterances contain a significant amount of prosodic variation. A tested candidate is the blizzard2013 dataset from the Blizzard Challenge, which contains many hours of high-quality audiobook recordings.
To get a license and download link for this dataset, you need to visit the [website](https://www.cstr.ed.ac.uk/projects/blizzard/2013/lessac_blizzard2013/license.html) of the Centre for Speech Technology Research of the University of Edinburgh.
You will get access to the raw dataset within a couple of days. A few preprocessing steps are required before the high-fidelity dataset can be used:
1. Get the forced time alignments for the blizzard dataset from [here](https://github.com/mueller91/tts_alignments).
2. Segment the high-fidelity audiobook files based on the instructions [here](https://github.com/Tomiinek/Blizzard2013_Segmentation). The training recipes below expect an LJSpeech-style layout after segmentation (see the sketch that follows).

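The recipes that follow reuse the LJSpeech formatter, so after segmentation the dataset is assumed to look roughly like this (paths and metadata format are assumptions):

```python
# Assumed LJSpeech-style layout after segmentation:
#   /srv/data/blizzard2013/segmented/
#   ├── metadata.csv        # "<file_id>|<transcription>|<normalized transcription>"
#   └── wavs/<file_id>.wav
from TTS.tts.configs.shared_configs import BaseDatasetConfig

dataset_config = BaseDatasetConfig(
    name="ljspeech",                    # reuse the LJSpeech formatter
    meta_file_train="metadata.csv",
    path="/srv/data/blizzard2013/segmented",
)
```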
View File

@ -0,0 +1,101 @@
import os
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig, CapacitronVAEConfig
from TTS.tts.configs.tacotron_config import TacotronConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.tacotron import Tacotron
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
output_path = os.path.dirname(os.path.abspath(__file__))
data_path = "/srv/data/"
# Using LJSpeech like dataset processing for the blizzard dataset
dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path=data_path)
audio_config = BaseAudioConfig(
sample_rate=24000,
do_trim_silence=True,
trim_db=60.0,
signal_norm=True,
mel_fmin=80.0,
mel_fmax=12000,
spec_gain=20.0,
log_func="np.log10",
ref_level_db=20,
preemphasis=0.0,
min_level_db=-100,
)
# Using the standard Capacitron config
capacitron_config = CapacitronVAEConfig(capacitron_VAE_loss_alpha=1.0)
config = TacotronConfig(
run_name="Blizzard-Capacitron-T1",
audio=audio_config,
capacitron_vae=capacitron_config,
use_capacitron_vae=True,
batch_size=128, # Tune this to your gpu
max_audio_len=6 * 24000, # Tune this to your gpu
min_audio_len=0.5 * 24000,
eval_batch_size=16,
num_loader_workers=12,
num_eval_loader_workers=8,
precompute_num_workers=24,
run_eval=True,
test_delay_epochs=5,
ga_alpha=0.0,
r=2,
optimizer="CapacitronOptimizer",
optimizer_params={"RAdam": {"betas": [0.9, 0.998], "weight_decay": 1e-6}, "SGD": {"lr": 1e-5, "momentum": 0.9}},
attention_type="graves",
attention_heads=5,
epochs=1000,
text_cleaner="phoneme_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phonemizer="espeak",
phoneme_cache_path=os.path.join(data_path, "phoneme_cache"),
stopnet_pos_weight=15,
print_step=50,
print_eval=True,
mixed_precision=False,
output_path=output_path,
datasets=[dataset_config],
lr=1e-3,
lr_scheduler="StepwiseGradualLR",
lr_scheduler_params={"gradual_learning_rates": [[0, 1e-3], [2e4, 5e-4], [4e5, 3e-4], [6e4, 1e-4], [8e4, 5e-5]]},
scheduler_after_epoch=False, # scheduler doesn't work without this flag
# Need to experiment with these below for capacitron
loss_masking=False,
decoder_loss_alpha=1.0,
postnet_loss_alpha=1.0,
postnet_diff_spec_alpha=0.0,
decoder_diff_spec_alpha=0.0,
decoder_ssim_alpha=0.0,
postnet_ssim_alpha=0.0,
)
ap = AudioProcessor(**config.audio.to_dict())
tokenizer, config = TTSTokenizer.init_from_config(config)
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
model = Tacotron(config, ap, tokenizer, speaker_manager=None)
trainer = Trainer(
TrainerArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
)
# 🚀
trainer.fit()

View File

@ -0,0 +1,117 @@
import os
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig, CapacitronVAEConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.tacotron2 import Tacotron2
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
output_path = os.path.dirname(os.path.abspath(__file__))
data_path = "/srv/data/blizzard2013/segmented"
# Using LJSpeech like dataset processing for the blizzard dataset
dataset_config = BaseDatasetConfig(
name="ljspeech",
meta_file_train="metadata.csv",
path=data_path,
)
audio_config = BaseAudioConfig(
sample_rate=24000,
do_trim_silence=True,
trim_db=60.0,
signal_norm=True,
mel_fmin=80.0,
mel_fmax=12000,
spec_gain=25.0,
log_func="np.log10",
ref_level_db=20,
preemphasis=0.0,
min_level_db=-100,
)
# Using the standard Capacitron config
capacitron_config = CapacitronVAEConfig(capacitron_VAE_loss_alpha=1.0)
config = Tacotron2Config(
run_name="Blizzard-Capacitron-T2",
audio=audio_config,
capacitron_vae=capacitron_config,
use_capacitron_vae=True,
batch_size=246, # Tune this to your gpu
max_audio_len=6 * 24000, # Tune this to your gpu
min_audio_len=1 * 24000,
eval_batch_size=16,
num_loader_workers=12,
num_eval_loader_workers=8,
precompute_num_workers=24,
run_eval=True,
test_delay_epochs=5,
ga_alpha=0.0,
r=2,
optimizer="CapacitronOptimizer",
optimizer_params={"RAdam": {"betas": [0.9, 0.998], "weight_decay": 1e-6}, "SGD": {"lr": 1e-5, "momentum": 0.9}},
attention_type="dynamic_convolution",
grad_clip=0.0, # Important! We overwrite the standard grad_clip with capacitron_grad_clip
double_decoder_consistency=False,
epochs=1000,
text_cleaner="phoneme_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phonemizer="espeak",
phoneme_cache_path=os.path.join(data_path, "phoneme_cache"),
stopnet_pos_weight=15,
print_step=25,
print_eval=True,
mixed_precision=False,
output_path=output_path,
datasets=[dataset_config],
lr=1e-3,
lr_scheduler="StepwiseGradualLR",
lr_scheduler_params={
"gradual_learning_rates": [
[0, 1e-3],
[2e4, 5e-4],
[4e4, 3e-4],
[6e4, 1e-4],
[8e4, 5e-5],
]
},
scheduler_after_epoch=False, # scheduler doesn't work without this flag
# dashboard_logger='wandb',
# sort_by_audio_len=True,
seq_len_norm=True,
# Need to experiment with these below for capacitron
loss_masking=False,
decoder_loss_alpha=1.0,
postnet_loss_alpha=1.0,
postnet_diff_spec_alpha=0.0,
decoder_diff_spec_alpha=0.0,
decoder_ssim_alpha=0.0,
postnet_ssim_alpha=0.0,
)
ap = AudioProcessor(**config.audio.to_dict())
tokenizer, config = TTSTokenizer.init_from_config(config)
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
model = Tacotron2(config, ap, tokenizer, speaker_manager=None)
trainer = Trainer(
TrainerArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
training_assets={"audio_processor": ap},
)
trainer.fit()

View File

@ -0,0 +1,115 @@
import os
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig, CapacitronVAEConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.tacotron2 import Tacotron2
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
output_path = os.path.dirname(os.path.abspath(__file__))
data_path = "/srv/data/"
# Using LJSpeech like dataset processing for the blizzard dataset
dataset_config = BaseDatasetConfig(
name="ljspeech",
meta_file_train="metadata.csv",
path=data_path,
)
audio_config = BaseAudioConfig(
sample_rate=22050,
do_trim_silence=True,
trim_db=60.0,
signal_norm=False,
mel_fmin=0.0,
mel_fmax=11025,
spec_gain=1.0,
log_func="np.log",
ref_level_db=20,
preemphasis=0.0,
)
# Using the standard Capacitron config
capacitron_config = CapacitronVAEConfig(capacitron_VAE_loss_alpha=1.0, capacitron_capacity=50)
config = Tacotron2Config(
run_name="Capacitron-Tacotron2",
audio=audio_config,
capacitron_vae=capacitron_config,
use_capacitron_vae=True,
batch_size=128, # Tune this to your gpu
max_audio_len=8 * 22050, # Tune this to your gpu
min_audio_len=1 * 22050,
eval_batch_size=16,
num_loader_workers=8,
num_eval_loader_workers=8,
precompute_num_workers=24,
run_eval=True,
test_delay_epochs=25,
ga_alpha=0.0,
r=2,
optimizer="CapacitronOptimizer",
optimizer_params={"RAdam": {"betas": [0.9, 0.998], "weight_decay": 1e-6}, "SGD": {"lr": 1e-5, "momentum": 0.9}},
attention_type="dynamic_convolution",
grad_clip=0.0, # Important! We overwrite the standard grad_clip with capacitron_grad_clip
double_decoder_consistency=False,
epochs=1000,
text_cleaner="phoneme_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phonemizer="espeak",
phoneme_cache_path=os.path.join(data_path, "phoneme_cache"),
stopnet_pos_weight=15,
print_step=25,
print_eval=True,
mixed_precision=False,
sort_by_audio_len=True,
seq_len_norm=True,
output_path=output_path,
datasets=[dataset_config],
lr=1e-3,
lr_scheduler="StepwiseGradualLR",
lr_scheduler_params={
"gradual_learning_rates": [
[0, 1e-3],
[2e4, 5e-4],
[4e4, 3e-4],
[6e4, 1e-4],
[8e4, 5e-5],
]
},
scheduler_after_epoch=False, # scheduler doesn't work without this flag
# Need to experiment with these below for capacitron
loss_masking=False,
decoder_loss_alpha=1.0,
postnet_loss_alpha=1.0,
postnet_diff_spec_alpha=0.0,
decoder_diff_spec_alpha=0.0,
decoder_ssim_alpha=0.0,
postnet_ssim_alpha=0.0,
)
ap = AudioProcessor(**config.audio.to_dict())
tokenizer, config = TTSTokenizer.init_from_config(config)
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
model = Tacotron2(config, ap, tokenizer, speaker_manager=None)
trainer = Trainer(
TrainerArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
training_assets={"audio_processor": ap},
)
trainer.fit()

View File

@ -6,7 +6,7 @@ import torch
from torch import nn, optim
from tests import get_tests_input_path
from TTS.tts.configs.shared_configs import GSTConfig
from TTS.tts.configs.shared_configs import CapacitronVAEConfig, GSTConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.layers.losses import MSELossMasked
from TTS.tts.models.tacotron2 import Tacotron2
@ -260,6 +260,73 @@ class TacotronGSTTrainTest(unittest.TestCase):
count += 1
class TacotronCapacitronTrainTest(unittest.TestCase):
@staticmethod
def test_train_step():
config = Tacotron2Config(
num_chars=32,
num_speakers=10,
use_speaker_embedding=True,
out_channels=80,
decoder_output_dim=80,
use_capacitron_vae=True,
capacitron_vae=CapacitronVAEConfig(),
optimizer="CapacitronOptimizer",
optimizer_params={
"RAdam": {"betas": [0.9, 0.998], "weight_decay": 1e-6},
"SGD": {"lr": 1e-5, "momentum": 0.9},
},
)
batch = dict({})
batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device)
batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device)
batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0]
batch["text_lengths"][0] = 128
batch["mel_input"] = torch.rand(8, 120, config.audio["num_mels"]).to(device)
batch["mel_lengths"] = torch.randint(20, 120, (8,)).long().to(device)
batch["mel_lengths"] = torch.sort(batch["mel_lengths"], descending=True)[0]
batch["mel_lengths"][0] = 120
batch["stop_targets"] = torch.zeros(8, 120, 1).float().to(device)
batch["stop_target_lengths"] = torch.randint(0, 120, (8,)).to(device)
batch["speaker_ids"] = torch.randint(0, 5, (8,)).long().to(device)
batch["d_vectors"] = None
for idx in batch["mel_lengths"]:
batch["stop_targets"][:, int(idx.item()) :, 0] = 1.0
batch["stop_targets"] = batch["stop_targets"].view(
batch["text_input"].shape[0], batch["stop_targets"].size(1) // config.r, -1
)
batch["stop_targets"] = (batch["stop_targets"].sum(2) > 0.0).unsqueeze(2).float().squeeze()
model = Tacotron2(config).to(device)
criterion = model.get_criterion()
optimizer = model.get_optimizer()
model.train()
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
assert (param - param_ref).sum() == 0, param
count += 1
for _ in range(10):
_, loss_dict = model.train_step(batch, criterion)
optimizer.zero_grad()
loss_dict["capacitron_vae_beta_loss"].backward()
optimizer.first_step()
loss_dict["loss"].backward()
optimizer.step()
# check parameter changes
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore pre-highway layer since it is conditional
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref
)
count += 1
class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
"""Test multi-speaker Tacotron2 with Global Style Tokens and d-vector inputs."""

View File

@ -6,7 +6,7 @@ import torch
from torch import nn, optim
from tests import get_tests_input_path
from TTS.tts.configs.shared_configs import GSTConfig
from TTS.tts.configs.shared_configs import CapacitronVAEConfig, GSTConfig
from TTS.tts.configs.tacotron_config import TacotronConfig
from TTS.tts.layers.losses import L1LossMasked
from TTS.tts.models.tacotron import Tacotron
@ -248,6 +248,74 @@ class TacotronGSTTrainTest(unittest.TestCase):
count += 1
class TacotronCapacitronTrainTest(unittest.TestCase):
@staticmethod
def test_train_step():
config = TacotronConfig(
num_chars=32,
num_speakers=10,
use_speaker_embedding=True,
out_channels=513,
decoder_output_dim=80,
use_capacitron_vae=True,
capacitron_vae=CapacitronVAEConfig(),
optimizer="CapacitronOptimizer",
optimizer_params={
"RAdam": {"betas": [0.9, 0.998], "weight_decay": 1e-6},
"SGD": {"lr": 1e-5, "momentum": 0.9},
},
)
batch = dict({})
batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device)
batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device)
batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0]
batch["text_lengths"][0] = 128
batch["linear_input"] = torch.rand(8, 120, config.audio["fft_size"] // 2 + 1).to(device)
batch["mel_input"] = torch.rand(8, 120, config.audio["num_mels"]).to(device)
batch["mel_lengths"] = torch.randint(20, 120, (8,)).long().to(device)
batch["mel_lengths"] = torch.sort(batch["mel_lengths"], descending=True)[0]
batch["mel_lengths"][0] = 120
batch["stop_targets"] = torch.zeros(8, 120, 1).float().to(device)
batch["stop_target_lengths"] = torch.randint(0, 120, (8,)).to(device)
batch["speaker_ids"] = torch.randint(0, 5, (8,)).long().to(device)
batch["d_vectors"] = None
for idx in batch["mel_lengths"]:
batch["stop_targets"][:, int(idx.item()) :, 0] = 1.0
batch["stop_targets"] = batch["stop_targets"].view(
batch["text_input"].shape[0], batch["stop_targets"].size(1) // config.r, -1
)
batch["stop_targets"] = (batch["stop_targets"].sum(2) > 0.0).unsqueeze(2).float().squeeze()
model = Tacotron(config).to(device)
criterion = model.get_criterion()
optimizer = model.get_optimizer()
model.train()
print(" > Num parameters for Tacotron with Capacitron VAE model:%s" % (count_parameters(model)))
model_ref = copy.deepcopy(model)
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
assert (param - param_ref).sum() == 0, param
count += 1
for _ in range(10):
_, loss_dict = model.train_step(batch, criterion)
optimizer.zero_grad()
loss_dict["capacitron_vae_beta_loss"].backward()
optimizer.first_step()
loss_dict["loss"].backward()
optimizer.step()
# check parameter changes
count = 0
for param, param_ref in zip(model.parameters(), model_ref.parameters()):
# ignore pre-highway layer since it is conditional
assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
count, param.shape, param, param_ref
)
count += 1
class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
@staticmethod
def test_train_step():