mirror of https://github.com/coqui-ai/TTS.git
Capacitron (#977)
* new CI config
* initial Capacitron implementation
* delete old unused file
* fix empty formatting changes
* update losses and training script
* fix previous commit
* fix commit
* Add Capacitron test and first round of test fixes
* revert formatter change
* add changes to the synthesizer
* add stepwise gradual lr scheduler and changes to the recipe
* add inference script for dev use
* feat: add posterior inference arguments to synth methods
  - added reference wav and text args for posterior inference
  - some formatting
* fix: add espeak flag to base_tts and dataset APIs
  - use_espeak_phonemes flag was not implemented in those APIs
  - espeak is now able to be utilised for phoneme generation
  - necessary phonemizer for the Capacitron model
* chore: update training script and style
  - training script includes the espeak flag and other hyperparams
  - made style
* chore: fix linting
* feat: add Tacotron 2 support
* leftover from dev
* chore: rename parser args
* feat: extract optimizers
  - created a separate optimizer class to merge the two optimizers
* chore: revert arbitrary trainer changes
* fmt: revert formatting bug
* formatting again
* formatting fixed
* fix: log func
* fix: update optimizer
  - implemented load_state_dict for continuing training
* fix: clean optimizer init for standard models
* improvement: purge espeak flags and add training scripts
* Delete capacitronT2.py (delete old training script, new one is pushed)
* feat: capacitron trainer methods
  - extracted capacitron specific training operations from the trainer into custom methods in taco1 and taco2 models
* chore: renaming and merging capacitron and gst style args
* fix: bug fixes from the previous commit
* fix: implement state_dict method on CapacitronOptimizer
* fix: call method
* fix: inference naming
* Delete train_capacitron.py
* fix: synthesize
* feat: update tests
* chore: fix style
* Delete capacitron_inference.py
* fix: fix train tts t2 capacitron tests
* fix: double forward in T2 train step
* fix: double forward in T1 train step
* fix: run make style
* fix: remove unused import
* fix: test for T1 capacitron
* fix: make lint
* feat: add blizzard2013 recipes
* make style
* fix: update recipes
* chore: make style
* Plot test sentences in Tacotron
* chore: make style and fix import
* fix: call forward first before problematic floordiv op
* fix: update recipes
* feat: add min_audio_len to recipes
* aux_input["style_mel"]
* chore: make style
* Make capacitron T2 recipe more stable
* Remove T1 capacitron Ljspeech
* feat: implement new grad clipping routine and update configs
* make style
* Add pretrained checkpoints
* Add default vocoder
* Change trainer package
* Fix grad clip issue for tacotron
* Fix scheduler issue with tacotron

Co-authored-by: Eren Gölge <egolge@coqui.ai>
Co-authored-by: WeberJulian <julian.weber@hotmail.fr>
Co-authored-by: Eren Gölge <erogol@hotmail.com>
parent ee99a6c1e2
commit 8be21ec387
@@ -119,6 +119,26 @@
                "license": "apache 2.0",
                "contact": "egolge@coqui.com"
            }
        },
        "blizzard2013": {
            "capacitron-t2-c50": {
                "description": "Capacitron additions to Tacotron 2 with Capacity at 50 as in https://arxiv.org/pdf/1906.03402.pdf",
                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.7.0_models/tts_models--en--blizzard2013--capacitron-t2-c50.zip",
                "commit": "d6284e7",
                "default_vocoder": "vocoder_models/en/blizzard2013/hifigan_v2",
                "author": "Adam Froghyar @a-froghyar",
                "license": "apache 2.0",
                "contact": "adamfroghyar@gmail.com"
            },
            "capacitron-t2-c150": {
                "description": "Capacitron additions to Tacotron 2 with Capacity at 150 as in https://arxiv.org/pdf/1906.03402.pdf",
                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.7.0_models/tts_models--en--blizzard2013--capacitron-t2-c150.zip",
                "commit": "d6284e7",
                "default_vocoder": "vocoder_models/en/blizzard2013/hifigan_v2",
                "author": "Adam Froghyar @a-froghyar",
                "license": "apache 2.0",
                "contact": "adamfroghyar@gmail.com"
            }
        }
    },
    "es": {
@@ -379,6 +399,16 @@
                "contact": "egolge@coqui.ai"
            }
        },
        "blizzard2013": {
            "hifigan_v2": {
                "description": "HiFiGAN_v2 LJSpeech vocoder from https://arxiv.org/abs/2010.05646.",
                "github_rls_url": "https://coqui.gateway.scarf.sh/v0.7.0_models/vocoder_models--en--blizzard2013--hifigan_v2.zip",
                "commit": "d6284e7",
                "author": "Adam Froghyar @a-froghyar",
                "license": "apache 2.0",
                "contact": "adamfroghyar@gmail.com"
            }
        },
        "vctk": {
            "hifigan_v2": {
                "description": "Finetuned and intended to be used with tts_models/en/vctk/sc-glow-tts",
@@ -172,6 +172,10 @@ If you don't specify any models, then it uses LJSpeech based English model.
        default=None,
    )
    parser.add_argument("--gst_style", help="Wav path file for GST style reference.", default=None)
    parser.add_argument(
        "--capacitron_style_wav", type=str, help="Wav path file for Capacitron prosody reference.", default=None
    )
    parser.add_argument("--capacitron_style_text", type=str, help="Transcription of the reference.", default=None)
    parser.add_argument(
        "--list_speaker_idxs",
        help="List available speaker ids for the defined multi-speaker model.",
@@ -308,6 +312,8 @@ If you don't specify any models, then it uses LJSpeech based English model.
        args.language_idx,
        args.speaker_wav,
        reference_wav=args.reference_wav,
        style_wav=args.capacitron_style_wav,
        style_text=args.capacitron_style_text,
        reference_speaker_name=args.reference_speaker_idx,
    )
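With these flags in place, a prosody reference can be passed on the command line. A hedged example invocation (paths are placeholders, assuming the packaged `tts` entry point for TTS/bin/synthesize.py and one of the Blizzard models registered above):

    tts --model_name tts_models/en/blizzard2013/capacitron-t2-c50 \
        --text "Hello, this is a test." \
        --capacitron_style_wav /path/to/reference.wav \
        --capacitron_style_text "Transcription of the reference wav." \
        --out_path output.wav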
@@ -48,6 +48,50 @@ class GSTConfig(Coqpit):
        check_argument("gst_num_style_tokens", c, restricted=True, min_val=1, max_val=1000)


@dataclass
class CapacitronVAEConfig(Coqpit):
    """Defines the Capacitron VAE module.

    Args:
        capacitron_capacity (int):
            Defines the variational capacity limit of the prosody embeddings. Defaults to 150.
        capacitron_VAE_embedding_dim (int):
            Defines the size of the Capacitron embedding vector dimension. Defaults to 128.
        capacitron_use_text_summary_embeddings (bool):
            If True, use a text summary embedding in Capacitron. Defaults to True.
        capacitron_text_summary_embedding_dim (int):
            Defines the size of the Capacitron text embedding vector dimension. Defaults to 128.
        capacitron_use_speaker_embedding (bool):
            If True, use speaker embeddings in Capacitron. Defaults to False.
        capacitron_VAE_loss_alpha (float):
            Weight for the VAE loss of the Tacotron model. If set less than or equal to zero, it disables the
            corresponding loss function. Defaults to 0.25.
        capacitron_grad_clip (float):
            Gradient clipping value for all gradients except beta. Defaults to 5.0.
    """

    capacitron_loss_alpha: int = 1
    capacitron_capacity: int = 150
    capacitron_VAE_embedding_dim: int = 128
    capacitron_use_text_summary_embeddings: bool = True
    capacitron_text_summary_embedding_dim: int = 128
    capacitron_use_speaker_embedding: bool = False
    capacitron_VAE_loss_alpha: float = 0.25
    capacitron_grad_clip: float = 5.0

    def check_values(
        self,
    ):
        """Check config fields"""
        c = asdict(self)
        super().check_values()
        check_argument("capacitron_capacity", c, restricted=True, min_val=10, max_val=500)
        check_argument("capacitron_VAE_embedding_dim", c, restricted=True, min_val=16, max_val=1024)
        check_argument("capacitron_use_speaker_embedding", c, restricted=False)
        check_argument("capacitron_text_summary_embedding_dim", c, restricted=False, min_val=16, max_val=512)
        check_argument("capacitron_VAE_loss_alpha", c, restricted=False)
        check_argument("capacitron_grad_clip", c, restricted=False)


@dataclass
class CharactersConfig(Coqpit):
    """Defines arguments for the `BaseCharacters` or `BaseVocabulary` and their subclasses.
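A minimal sketch of instantiating and validating this config (field values chosen only for illustration):

    from TTS.tts.configs.shared_configs import CapacitronVAEConfig

    capacitron_config = CapacitronVAEConfig(capacitron_capacity=50, capacitron_VAE_embedding_dim=128)
    capacitron_config.check_values()  # raises if e.g. capacitron_capacity falls outside [10, 500]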
@@ -1,7 +1,7 @@
from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig, GSTConfig
from TTS.tts.configs.shared_configs import BaseTTSConfig, CapacitronVAEConfig, GSTConfig


@dataclass
@@ -23,6 +23,10 @@ class TacotronConfig(BaseTTSConfig):
        gst_style_input (str):
            Path to the wav file used at inference to set the speech style through GST. If `GST` is enabled and
            this is not defined, the model uses a zero vector as an input. Defaults to None.
        use_capacitron_vae (bool):
            Enable / disable the use of Capacitron modules. Defaults to False.
        capacitron_vae (CapacitronVAEConfig):
            Instance of `CapacitronVAEConfig` class.
        num_chars (int):
            Number of characters used by the model. It must be defined before initializing the model. Defaults to None.
        num_speakers (int):
@@ -143,6 +147,9 @@ class TacotronConfig(BaseTTSConfig):
    gst: GSTConfig = None
    gst_style_input: str = None

    use_capacitron_vae: bool = False
    capacitron_vae: CapacitronVAEConfig = None

    # model specific params
    num_speakers: int = 1
    num_chars: int = 0
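For context, a sketch of wiring the new fields together, mirroring the Blizzard recipes added later in this commit (values are illustrative):

    from TTS.tts.configs.shared_configs import CapacitronVAEConfig
    from TTS.tts.configs.tacotron2_config import Tacotron2Config

    config = Tacotron2Config(
        use_capacitron_vae=True,
        capacitron_vae=CapacitronVAEConfig(capacitron_VAE_loss_alpha=1.0),
        grad_clip=0.0,  # the recipes disable the standard clip; capacitron_grad_clip is used instead
    )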
@@ -281,6 +281,10 @@ class TacotronLoss(torch.nn.Module):
    def __init__(self, c, ga_sigma=0.4):
        super().__init__()
        self.stopnet_pos_weight = c.stopnet_pos_weight
        self.use_capacitron_vae = c.use_capacitron_vae
        if self.use_capacitron_vae:
            self.capacitron_capacity = c.capacitron_vae.capacitron_capacity
            self.capacitron_vae_loss_alpha = c.capacitron_vae.capacitron_VAE_loss_alpha
        self.ga_alpha = c.ga_alpha
        self.decoder_diff_spec_alpha = c.decoder_diff_spec_alpha
        self.postnet_diff_spec_alpha = c.postnet_diff_spec_alpha
@@ -308,6 +312,9 @@ class TacotronLoss(torch.nn.Module):
        # pylint: disable=not-callable
        self.criterion_st = BCELossMasked(pos_weight=torch.tensor(self.stopnet_pos_weight)) if c.stopnet else None

        # For dev purposes only
        self.criterion_capacitron_reconstruction_loss = nn.L1Loss(reduction="sum")

    def forward(
        self,
        postnet_output,
@@ -317,6 +324,7 @@ class TacotronLoss(torch.nn.Module):
        stopnet_output,
        stopnet_target,
        stop_target_length,
        capacitron_vae_outputs,
        output_lens,
        decoder_b_output,
        alignments,
@@ -348,6 +356,55 @@ class TacotronLoss(torch.nn.Module):
        return_dict["decoder_loss"] = decoder_loss
        return_dict["postnet_loss"] = postnet_loss

        if self.use_capacitron_vae:
            # extract capacitron vae infos
            posterior_distribution, prior_distribution, beta = capacitron_vae_outputs

            # KL divergence term between the posterior and the prior
            kl_term = torch.mean(torch.distributions.kl_divergence(posterior_distribution, prior_distribution))

            # Limit the mutual information between the data and latent space by the variational capacity limit
            kl_capacity = kl_term - self.capacitron_capacity

            # pass beta through softplus to keep it positive
            beta = torch.nn.functional.softplus(beta)[0]

            # This is the term going to the main ADAM optimiser, we detach beta because
            # beta is optimised by a separate, SGD optimiser below
            capacitron_vae_loss = beta.detach() * kl_capacity

            # normalize the capacitron_vae_loss as in L1Loss or MSELoss.
            # After this, both the standard loss and capacitron_vae_loss will be in the same scale.
            # For this reason we don't need to use L1Loss and MSELoss in "sum" reduction mode.
            # Note: the batch is not considered because the L1Loss was calculated in "sum" mode
            # divided by the batch size, so not dividing the capacitron_vae_loss by B is legitimate.

            # get B T D dimension from input
            B, T, D = mel_input.size()
            # normalize
            if self.config.loss_masking:
                # if mask loss get T using the mask
                T = output_lens.sum() / B

            # Only for dev purposes to be able to compare the reconstruction loss with the values in the
            # original Capacitron paper
            return_dict["capaciton_reconstruction_loss"] = (
                self.criterion_capacitron_reconstruction_loss(decoder_output, mel_input) / decoder_output.size(0)
            ) + kl_capacity

            capacitron_vae_loss = capacitron_vae_loss / (T * D)
            capacitron_vae_loss = capacitron_vae_loss * self.capacitron_vae_loss_alpha

            # This is the term to purely optimise beta and to pass into the SGD optimizer
            beta_loss = torch.negative(beta) * kl_capacity.detach()

            loss += capacitron_vae_loss

            return_dict["capacitron_vae_loss"] = capacitron_vae_loss
            return_dict["capacitron_vae_beta_loss"] = beta_loss
            return_dict["capacitron_vae_kl_term"] = kl_term
            return_dict["capacitron_beta"] = beta

        stop_loss = (
            self.criterion_st(stopnet_output, stopnet_target, stop_target_length)
            if self.config.stopnet
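The block above implements the capacity-limited VAE objective: the primary optimizer sees beta.detach() * (KL - C) added to the usual Tacotron losses, while beta itself is trained only through the negated, detached term. A toy, self-contained sketch of just this bookkeeping (dimensions and distributions are made up for illustration):

    import torch
    from torch.distributions import MultivariateNormal, kl_divergence

    capacity = 150.0                                    # capacitron_capacity
    raw_beta = torch.nn.Parameter(torch.tensor([1.0]))  # the learnable Lagrange-like multiplier
    beta = torch.nn.functional.softplus(raw_beta)[0]    # kept positive, as in the loss above
    prior = MultivariateNormal(torch.zeros(128), torch.eye(128))
    posterior = MultivariateNormal(torch.randn(128), 0.5 * torch.eye(128))

    kl_term = kl_divergence(posterior, prior).mean()
    kl_capacity = kl_term - capacity
    capacitron_vae_loss = beta.detach() * kl_capacity        # goes to the primary optimizer
    beta_loss = torch.negative(beta) * kl_capacity.detach()  # goes to the secondary (SGD) optimizer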
@ -0,0 +1,205 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from torch.distributions.multivariate_normal import MultivariateNormal as MVN
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class CapacitronVAE(nn.Module):
|
||||
"""Effective Use of Variational Embedding Capacity for prosody transfer.
|
||||
|
||||
See https://arxiv.org/abs/1906.03402"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_mel,
|
||||
capacitron_VAE_embedding_dim,
|
||||
encoder_output_dim=256,
|
||||
reference_encoder_out_dim=128,
|
||||
speaker_embedding_dim=None,
|
||||
text_summary_embedding_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
# Init distributions
|
||||
self.prior_distribution = MVN(
|
||||
torch.zeros(capacitron_VAE_embedding_dim), torch.eye(capacitron_VAE_embedding_dim)
|
||||
)
|
||||
self.approximate_posterior_distribution = None
|
||||
# define output ReferenceEncoder dim to the capacitron_VAE_embedding_dim
|
||||
self.encoder = ReferenceEncoder(num_mel, out_dim=reference_encoder_out_dim)
|
||||
|
||||
# Init beta, the lagrange-like term for the KL distribution
|
||||
self.beta = torch.nn.Parameter(torch.log(torch.exp(torch.Tensor([1.0])) - 1), requires_grad=True)
|
||||
mlp_input_dimension = reference_encoder_out_dim
|
||||
|
||||
if text_summary_embedding_dim is not None:
|
||||
self.text_summary_net = TextSummary(text_summary_embedding_dim, encoder_output_dim=encoder_output_dim)
|
||||
mlp_input_dimension += text_summary_embedding_dim
|
||||
if speaker_embedding_dim is not None:
|
||||
# TODO: Test a multispeaker model!
|
||||
mlp_input_dimension += speaker_embedding_dim
|
||||
self.post_encoder_mlp = PostEncoderMLP(mlp_input_dimension, capacitron_VAE_embedding_dim)
|
||||
|
||||
def forward(self, reference_mel_info=None, text_info=None, speaker_embedding=None):
|
||||
# Use reference
|
||||
if reference_mel_info is not None:
|
||||
reference_mels = reference_mel_info[0] # [batch_size, num_frames, num_mels]
|
||||
mel_lengths = reference_mel_info[1] # [batch_size]
|
||||
enc_out = self.encoder(reference_mels, mel_lengths)
|
||||
|
||||
# concat speaker_embedding and/or text summary embedding
|
||||
if text_info is not None:
|
||||
text_inputs = text_info[0] # [batch_size, num_characters, num_embedding]
|
||||
input_lengths = text_info[1]
|
||||
text_summary_out = self.text_summary_net(text_inputs, input_lengths).to(reference_mels.device)
|
||||
enc_out = torch.cat([enc_out, text_summary_out], dim=-1)
|
||||
if speaker_embedding is not None:
|
||||
enc_out = torch.cat([enc_out, speaker_embedding], dim=-1)
|
||||
|
||||
# Feed the output of the ref encoder and information about text/speaker into
|
||||
# an MLP to produce the parameters for the approximate posterior distributions
|
||||
mu, sigma = self.post_encoder_mlp(enc_out)
|
||||
# convert to cpu because prior_distribution was created on cpu
|
||||
mu = mu.cpu()
|
||||
sigma = sigma.cpu()
|
||||
|
||||
# Sample from the posterior: z ~ q(z|x)
|
||||
self.approximate_posterior_distribution = MVN(mu, torch.diag_embed(sigma))
|
||||
VAE_embedding = self.approximate_posterior_distribution.rsample()
|
||||
# Infer from the model, bypasses encoding
|
||||
else:
|
||||
# Sample from the prior: z ~ p(z)
|
||||
VAE_embedding = self.prior_distribution.sample().unsqueeze(0)
|
||||
|
||||
# reshape to [batch_size, 1, capacitron_VAE_embedding_dim]
|
||||
return VAE_embedding.unsqueeze(1), self.approximate_posterior_distribution, self.prior_distribution, self.beta
|
||||
|
||||
|
||||
class ReferenceEncoder(nn.Module):
|
||||
"""NN module creating a fixed size prosody embedding from a spectrogram.
|
||||
|
||||
inputs: mel spectrograms [batch_size, num_spec_frames, num_mel]
|
||||
outputs: [batch_size, embedding_dim]
|
||||
"""
|
||||
|
||||
def __init__(self, num_mel, out_dim):
|
||||
|
||||
super().__init__()
|
||||
self.num_mel = num_mel
|
||||
filters = [1] + [32, 32, 64, 64, 128, 128]
|
||||
num_layers = len(filters) - 1
|
||||
convs = [
|
||||
nn.Conv2d(
|
||||
in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(2, 2)
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
self.convs = nn.ModuleList(convs)
|
||||
self.training = False
|
||||
self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]])
|
||||
|
||||
post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 2, num_layers)
|
||||
self.recurrence = nn.LSTM(
|
||||
input_size=filters[-1] * post_conv_height, hidden_size=out_dim, batch_first=True, bidirectional=False
|
||||
)
|
||||
|
||||
def forward(self, inputs, input_lengths):
|
||||
batch_size = inputs.size(0)
|
||||
x = inputs.view(batch_size, 1, -1, self.num_mel) # [batch_size, num_channels==1, num_frames, num_mel]
|
||||
valid_lengths = input_lengths.float() # [batch_size]
|
||||
for conv, bn in zip(self.convs, self.bns):
|
||||
x = conv(x)
|
||||
x = bn(x)
|
||||
x = F.relu(x)
|
||||
|
||||
# Create the post conv width mask based on the valid lengths of the output of the convolution.
|
||||
# The valid lengths for the output of a convolution on varying length inputs is
|
||||
# ceil(input_length/stride) + 1 for stride=3 and padding=2
|
||||
# For example (kernel_size=3, stride=2, padding=2):
|
||||
# 0 0 x x x x x 0 0 -> Input = 5, 0 is zero padding, x is valid values coming from padding=2 in conv2d
|
||||
# _____
|
||||
# x _____
|
||||
# x _____
|
||||
# x ____
|
||||
# x
|
||||
# x x x x -> Output valid length = 4
|
||||
# Since every example in the batch is zero padded and therefore has a separate valid_length,
|
||||
# we need to mask off all the values AFTER the valid length for each example in the batch.
|
||||
# Otherwise, the convolutions create noise and a lot of spurious information
|
||||
valid_lengths = (valid_lengths / 2).float()
|
||||
valid_lengths = torch.ceil(valid_lengths).to(dtype=torch.int64) + 1 # 2 is stride -- size: [batch_size]
|
||||
post_conv_max_width = x.size(2)
|
||||
|
||||
mask = torch.arange(post_conv_max_width).to(inputs.device).expand(
|
||||
len(valid_lengths), post_conv_max_width
|
||||
) < valid_lengths.unsqueeze(1)
|
||||
mask = mask.expand(1, 1, -1, -1).transpose(2, 0).transpose(-1, 2) # [batch_size, 1, post_conv_max_width, 1]
|
||||
x = x * mask
|
||||
|
||||
x = x.transpose(1, 2)
|
||||
# x: 4D tensor [batch_size, post_conv_width,
|
||||
# num_channels==128, post_conv_height]
|
||||
|
||||
post_conv_width = x.size(1)
|
||||
x = x.contiguous().view(batch_size, post_conv_width, -1)
|
||||
# x: 3D tensor [batch_size, post_conv_width,
|
||||
# num_channels*post_conv_height]
|
||||
|
||||
# Routine for fetching the last valid output of a dynamic LSTM with varying input lengths and padding
|
||||
post_conv_input_lengths = valid_lengths
|
||||
packed_seqs = nn.utils.rnn.pack_padded_sequence(
|
||||
x, post_conv_input_lengths.tolist(), batch_first=True, enforce_sorted=False
|
||||
) # dynamic rnn sequence padding
|
||||
self.recurrence.flatten_parameters()
|
||||
_, (ht, _) = self.recurrence(packed_seqs)
|
||||
last_output = ht[-1]
|
||||
|
||||
return last_output.to(inputs.device) # [B, 128]
|
||||
|
||||
@staticmethod
|
||||
def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs):
|
||||
"""Height of spec after n convolutions with fixed kernel/stride/pad."""
|
||||
for _ in range(n_convs):
|
||||
height = (height - kernel_size + 2 * pad) // stride + 1
|
||||
return height
|
||||
|
||||
|
||||
class TextSummary(nn.Module):
|
||||
def __init__(self, embedding_dim, encoder_output_dim):
|
||||
super().__init__()
|
||||
self.lstm = nn.LSTM(
|
||||
encoder_output_dim, # text embedding dimension from the text encoder
|
||||
embedding_dim, # fixed length output summary the lstm creates from the input
|
||||
batch_first=True,
|
||||
bidirectional=False,
|
||||
)
|
||||
|
||||
def forward(self, inputs, input_lengths):
|
||||
# Routine for fetching the last valid output of a dynamic LSTM with varying input lengths and padding
|
||||
packed_seqs = nn.utils.rnn.pack_padded_sequence(
|
||||
inputs, input_lengths.tolist(), batch_first=True, enforce_sorted=False
|
||||
) # dynamic rnn sequence padding
|
||||
self.lstm.flatten_parameters()
|
||||
_, (ht, _) = self.lstm(packed_seqs)
|
||||
last_output = ht[-1]
|
||||
return last_output
|
||||
|
||||
|
||||
class PostEncoderMLP(nn.Module):
|
||||
def __init__(self, input_size, hidden_size):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
modules = [
|
||||
nn.Linear(input_size, hidden_size), # Hidden Layer
|
||||
nn.Tanh(),
|
||||
nn.Linear(hidden_size, hidden_size * 2),
|
||||
] # Output layer twice the size for mean and variance
|
||||
self.net = nn.Sequential(*modules)
|
||||
self.softplus = nn.Softplus()
|
||||
|
||||
def forward(self, _input):
|
||||
mlp_output = self.net(_input)
|
||||
# The mean parameter is unconstrained
|
||||
mu = mlp_output[:, : self.hidden_size]
|
||||
# The standard deviation must be positive. Parameterise with a softplus
|
||||
sigma = self.softplus(mlp_output[:, self.hidden_size :])
|
||||
return mu, sigma
|
|
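The CapacitronVAE module added above can also be exercised on its own. A minimal sketch of the prior-sampling path (no reference mel is given, so the layer draws a random but plausible prosody embedding; dimensions are illustrative):

    from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE

    layer = CapacitronVAE(num_mel=80, capacitron_VAE_embedding_dim=128)
    emb, posterior, prior, beta = layer(reference_mel_info=None)
    print(emb.shape)  # torch.Size([1, 1, 128]); posterior is None on this path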
@ -139,7 +139,7 @@ class MultiHeadAttention(nn.Module):
|
|||
keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
|
||||
values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
|
||||
|
||||
# score = softmax(QK^T / (d_k ** 0.5))
|
||||
# score = softmax(QK^T / (d_k**0.5))
|
||||
scores = torch.matmul(queries, keys.transpose(2, 3)) # [h, N, T_q, T_k]
|
||||
scores = scores / (self.key_dim**0.5)
|
||||
scores = F.softmax(scores, dim=3)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import copy
|
||||
from abc import abstractmethod
|
||||
from typing import Dict
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
|
@ -10,7 +10,9 @@ from TTS.tts.layers.losses import TacotronLoss
|
|||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import sequence_mask
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.generic_utils import format_aux_input
|
||||
from TTS.utils.io import load_fsspec
|
||||
from TTS.utils.training import gradual_training_scheduler
|
||||
|
@ -47,6 +49,11 @@ class BaseTacotron(BaseTTS):
|
|||
self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim
|
||||
self.gst_layer = None
|
||||
|
||||
# Capacitron
|
||||
if self.capacitron_vae and self.use_capacitron_vae:
|
||||
self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim # add capacitron embedding dim
|
||||
self.capacitron_vae_layer = None
|
||||
|
||||
# additional layers
|
||||
self.decoder_backward = None
|
||||
self.coarse_decoder = None
|
||||
|
@ -125,6 +132,53 @@ class BaseTacotron(BaseTTS):
|
|||
speaker_manager = SpeakerManager.init_from_config(config)
|
||||
return BaseTacotron(config, ap, tokenizer, speaker_manager)
|
||||
|
||||
##########################
|
||||
# TEST AND LOG FUNCTIONS #
|
||||
##########################
|
||||
|
||||
def test_run(self, assets: Dict) -> Tuple[Dict, Dict]:
|
||||
"""Generic test run for `tts` models used by `Trainer`.
|
||||
|
||||
You can override this for a different behaviour.
|
||||
|
||||
Args:
|
||||
assets (dict): A dict of training assets. For `tts` models, it must include `{'audio_processor': ap}`.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
|
||||
"""
|
||||
print(" | > Synthesizing test sentences.")
|
||||
test_audios = {}
|
||||
test_figures = {}
|
||||
test_sentences = self.config.test_sentences
|
||||
aux_inputs = self._get_test_aux_input()
|
||||
for idx, sen in enumerate(test_sentences):
|
||||
outputs_dict = synthesis(
|
||||
self,
|
||||
sen,
|
||||
self.config,
|
||||
"cuda" in str(next(self.parameters()).device),
|
||||
speaker_id=aux_inputs["speaker_id"],
|
||||
d_vector=aux_inputs["d_vector"],
|
||||
style_wav=aux_inputs["style_wav"],
|
||||
use_griffin_lim=True,
|
||||
do_trim_silence=False,
|
||||
)
|
||||
test_audios["{}-audio".format(idx)] = outputs_dict["wav"]
|
||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
|
||||
outputs_dict["outputs"]["model_outputs"], self.ap, output_fig=False
|
||||
)
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(
|
||||
outputs_dict["outputs"]["alignments"], output_fig=False
|
||||
)
|
||||
return {"figures": test_figures, "audios": test_audios}
|
||||
|
||||
def test_log(
|
||||
self, outputs: dict, logger: "Logger", assets: dict, steps: int # pylint: disable=unused-argument
|
||||
) -> None:
|
||||
logger.test_audios(steps, outputs["audios"], self.ap.sample_rate)
|
||||
logger.test_figures(steps, outputs["figures"])
|
||||
|
||||
#############################
|
||||
# COMMON COMPUTE FUNCTIONS
|
||||
#############################
|
||||
|
@ -160,7 +214,9 @@ class BaseTacotron(BaseTTS):
|
|||
)
|
||||
# scale_factor = self.decoder.r_init / self.decoder.r
|
||||
alignments_backward = torch.nn.functional.interpolate(
|
||||
alignments_backward.transpose(1, 2), size=alignments.shape[1], mode="nearest"
|
||||
alignments_backward.transpose(1, 2),
|
||||
size=alignments.shape[1],
|
||||
mode="nearest",
|
||||
).transpose(1, 2)
|
||||
decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2)
|
||||
decoder_outputs_backward = decoder_outputs_backward[:, :T, :]
|
||||
|
@ -193,6 +249,25 @@ class BaseTacotron(BaseTTS):
|
|||
inputs = self._concat_speaker_embedding(inputs, gst_outputs)
|
||||
return inputs
|
||||
|
||||
def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None):
|
||||
"""Capacitron Variational Autoencoder"""
|
||||
(VAE_outputs, posterior_distribution, prior_distribution, capacitron_beta,) = self.capacitron_vae_layer(
|
||||
reference_mel_info,
|
||||
text_info,
|
||||
speaker_embedding, # pylint: disable=not-callable
|
||||
)
|
||||
|
||||
VAE_outputs = VAE_outputs.to(inputs.device)
|
||||
encoder_output = self._concat_speaker_embedding(
|
||||
inputs, VAE_outputs
|
||||
) # concatenate to the output of the basic tacotron encoder
|
||||
return (
|
||||
encoder_output,
|
||||
posterior_distribution,
|
||||
prior_distribution,
|
||||
capacitron_beta,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _add_speaker_embedding(outputs, embedded_speakers):
|
||||
embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
# coding: utf-8
|
||||
|
||||
from typing import Dict, List, Union
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||
|
||||
from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE
|
||||
from TTS.tts.layers.tacotron.gst_layers import GST
|
||||
from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
|
||||
from TTS.tts.models.base_tacotron import BaseTacotron
|
||||
|
@ -13,6 +15,7 @@ from TTS.tts.utils.measures import alignment_diagonal_score
|
|||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.capacitron_optimizer import CapacitronOptimizer
|
||||
|
||||
|
||||
class Tacotron(BaseTacotron):
|
||||
|
@ -51,6 +54,9 @@ class Tacotron(BaseTacotron):
|
|||
if self.use_gst:
|
||||
self.decoder_in_features += self.gst.gst_embedding_dim
|
||||
|
||||
if self.use_capacitron_vae:
|
||||
self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim
|
||||
|
||||
# embedding layer
|
||||
self.embedding = nn.Embedding(self.num_chars, 256, padding_idx=0)
|
||||
self.embedding.weight.data.normal_(0, 0.3)
|
||||
|
@ -90,6 +96,20 @@ class Tacotron(BaseTacotron):
|
|||
gst_embedding_dim=self.gst.gst_embedding_dim,
|
||||
)
|
||||
|
||||
# Capacitron layers
|
||||
if self.capacitron_vae and self.use_capacitron_vae:
|
||||
self.capacitron_vae_layer = CapacitronVAE(
|
||||
num_mel=self.decoder_output_dim,
|
||||
encoder_output_dim=self.encoder_in_features,
|
||||
capacitron_VAE_embedding_dim=self.capacitron_vae.capacitron_VAE_embedding_dim,
|
||||
speaker_embedding_dim=self.embedded_speaker_dim
|
||||
if self.use_speaker_embedding and self.capacitron_vae.capacitron_use_speaker_embedding
|
||||
else None,
|
||||
text_summary_embedding_dim=self.capacitron_vae.capacitron_text_summary_embedding_dim
|
||||
if self.capacitron_vae.capacitron_use_text_summary_embeddings
|
||||
else None,
|
||||
)
|
||||
|
||||
# backward pass decoder
|
||||
if self.bidirectional_decoder:
|
||||
self._init_backward_decoder()
|
||||
|
@ -146,6 +166,19 @@ class Tacotron(BaseTacotron):
|
|||
# B x 1 x speaker_embed_dim
|
||||
embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1)
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers)
|
||||
# Capacitron
|
||||
if self.capacitron_vae and self.use_capacitron_vae:
|
||||
# B x capacitron_VAE_embedding_dim
|
||||
encoder_outputs, *capacitron_vae_outputs = self.compute_capacitron_VAE_embedding(
|
||||
encoder_outputs,
|
||||
reference_mel_info=[mel_specs, mel_lengths],
|
||||
text_info=[inputs, text_lengths]
|
||||
if self.capacitron_vae.capacitron_use_text_summary_embeddings
|
||||
else None,
|
||||
speaker_embedding=embedded_speakers if self.capacitron_vae.capacitron_use_speaker_embedding else None,
|
||||
)
|
||||
else:
|
||||
capacitron_vae_outputs = None
|
||||
# decoder_outputs: B x decoder_in_features x T_out
|
||||
# alignments: B x T_in x encoder_in_features
|
||||
# stop_tokens: B x T_in
|
||||
|
@ -178,6 +211,7 @@ class Tacotron(BaseTacotron):
|
|||
"decoder_outputs": decoder_outputs,
|
||||
"alignments": alignments,
|
||||
"stop_tokens": stop_tokens,
|
||||
"capacitron_vae_outputs": capacitron_vae_outputs,
|
||||
}
|
||||
)
|
||||
return outputs
|
||||
|
@ -190,6 +224,28 @@ class Tacotron(BaseTacotron):
|
|||
if self.gst and self.use_gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
|
||||
if self.capacitron_vae and self.use_capacitron_vae:
|
||||
if aux_input["style_text"] is not None:
|
||||
style_text_embedding = self.embedding(aux_input["style_text"])
|
||||
style_text_length = torch.tensor([style_text_embedding.size(1)], dtype=torch.int64).to(
|
||||
encoder_outputs.device
|
||||
) # pylint: disable=not-callable
|
||||
reference_mel_length = (
|
||||
torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
|
||||
if aux_input["style_mel"] is not None
|
||||
else None
|
||||
) # pylint: disable=not-callable
|
||||
# B x capacitron_VAE_embedding_dim
|
||||
encoder_outputs, *_ = self.compute_capacitron_VAE_embedding(
|
||||
encoder_outputs,
|
||||
reference_mel_info=[aux_input["style_mel"], reference_mel_length]
|
||||
if aux_input["style_mel"] is not None
|
||||
else None,
|
||||
text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None,
|
||||
speaker_embedding=aux_input["d_vectors"]
|
||||
if self.capacitron_vae.capacitron_use_speaker_embedding
|
||||
else None,
|
||||
)
|
||||
if self.num_speakers > 1:
|
||||
if not self.use_d_vector_file:
|
||||
# B x 1 x speaker_embed_dim
|
||||
|
@ -215,12 +271,19 @@ class Tacotron(BaseTacotron):
|
|||
}
|
||||
return outputs
|
||||
|
||||
def train_step(self, batch, criterion):
|
||||
"""Perform a single training step by fetching the right set if samples from the batch.
|
||||
def before_backward_pass(self, loss_dict, optimizer) -> None:
|
||||
# Extracting custom training specific operations for capacitron
|
||||
# from the trainer
|
||||
if self.use_capacitron_vae:
|
||||
loss_dict["capacitron_vae_beta_loss"].backward()
|
||||
optimizer.first_step()
|
||||
|
||||
def train_step(self, batch: Dict, criterion: torch.nn.Module) -> Tuple[Dict, Dict]:
|
||||
"""Perform a single training step by fetching the right set of samples from the batch.
|
||||
|
||||
Args:
|
||||
batch ([type]): [description]
|
||||
criterion ([type]): [description]
|
||||
batch ([Dict]): A dictionary of input tensors.
|
||||
criterion ([torch.nn.Module]): Callable criterion to compute model loss.
|
||||
"""
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
|
@ -232,14 +295,8 @@ class Tacotron(BaseTacotron):
|
|||
speaker_ids = batch["speaker_ids"]
|
||||
d_vectors = batch["d_vectors"]
|
||||
|
||||
# forward pass model
|
||||
outputs = self.forward(
|
||||
text_input,
|
||||
text_lengths,
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
aux_input={"speaker_ids": speaker_ids, "d_vectors": d_vectors},
|
||||
)
|
||||
aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)
|
||||
|
||||
# set the [alignment] lengths wrt reduction factor for guided attention
|
||||
if mel_lengths.max() % self.decoder.r != 0:
|
||||
|
@ -249,9 +306,6 @@ class Tacotron(BaseTacotron):
|
|||
else:
|
||||
alignment_lengths = mel_lengths // self.decoder.r
|
||||
|
||||
aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)
|
||||
|
||||
# compute loss
|
||||
with autocast(enabled=False): # use float32 for the criterion
|
||||
loss_dict = criterion(
|
||||
|
@ -262,6 +316,7 @@ class Tacotron(BaseTacotron):
|
|||
outputs["stop_tokens"].float(),
|
||||
stop_targets.float(),
|
||||
stop_target_lengths,
|
||||
outputs["capacitron_vae_outputs"] if self.capacitron_vae else None,
|
||||
mel_lengths,
|
||||
None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(),
|
||||
outputs["alignments"].float(),
|
||||
|
@ -275,6 +330,25 @@ class Tacotron(BaseTacotron):
|
|||
loss_dict["align_error"] = align_error
|
||||
return outputs, loss_dict
|
||||
|
||||
def get_optimizer(self) -> List:
|
||||
if self.use_capacitron_vae:
|
||||
return CapacitronOptimizer(self.config, self.named_parameters())
|
||||
return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self)
|
||||
|
||||
def get_scheduler(self, optimizer: object):
|
||||
opt = optimizer.primary_optimizer if self.use_capacitron_vae else optimizer
|
||||
return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, opt)
|
||||
|
||||
def before_gradient_clipping(self):
|
||||
if self.use_capacitron_vae:
|
||||
# Capacitron model specific gradient clipping
|
||||
model_params_to_clip = []
|
||||
for name, param in self.named_parameters():
|
||||
if param.requires_grad:
|
||||
if name != "capacitron_vae_layer.beta":
|
||||
model_params_to_clip.append(param)
|
||||
torch.nn.utils.clip_grad_norm_(model_params_to_clip, self.capacitron_vae.capacitron_grad_clip)
|
||||
|
||||
def _create_logs(self, batch, outputs, ap):
|
||||
postnet_outputs = outputs["model_outputs"]
|
||||
decoder_outputs = outputs["decoder_outputs"]
|
||||
|
|
|
@ -5,7 +5,9 @@ from typing import Dict, List, Union
|
|||
import torch
|
||||
from torch import nn
|
||||
from torch.cuda.amp.autocast_mode import autocast
|
||||
from trainer.trainer_utils import get_optimizer, get_scheduler
|
||||
|
||||
from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE
|
||||
from TTS.tts.layers.tacotron.gst_layers import GST
|
||||
from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
|
||||
from TTS.tts.models.base_tacotron import BaseTacotron
|
||||
|
@ -13,6 +15,7 @@ from TTS.tts.utils.measures import alignment_diagonal_score
|
|||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.capacitron_optimizer import CapacitronOptimizer
|
||||
|
||||
|
||||
class Tacotron2(BaseTacotron):
|
||||
|
@ -65,6 +68,9 @@ class Tacotron2(BaseTacotron):
|
|||
if self.use_gst:
|
||||
self.decoder_in_features += self.gst.gst_embedding_dim
|
||||
|
||||
if self.use_capacitron_vae:
|
||||
self.decoder_in_features += self.capacitron_vae.capacitron_VAE_embedding_dim
|
||||
|
||||
# embedding layer
|
||||
self.embedding = nn.Embedding(self.num_chars, 512, padding_idx=0)
|
||||
|
||||
|
@ -102,6 +108,20 @@ class Tacotron2(BaseTacotron):
|
|||
gst_embedding_dim=self.gst.gst_embedding_dim,
|
||||
)
|
||||
|
||||
# Capacitron VAE Layers
|
||||
if self.capacitron_vae and self.use_capacitron_vae:
|
||||
self.capacitron_vae_layer = CapacitronVAE(
|
||||
num_mel=self.decoder_output_dim,
|
||||
encoder_output_dim=self.encoder_in_features,
|
||||
capacitron_VAE_embedding_dim=self.capacitron_vae.capacitron_VAE_embedding_dim,
|
||||
speaker_embedding_dim=self.embedded_speaker_dim
|
||||
if self.capacitron_vae.capacitron_use_speaker_embedding
|
||||
else None,
|
||||
text_summary_embedding_dim=self.capacitron_vae.capacitron_text_summary_embedding_dim
|
||||
if self.capacitron_vae.capacitron_use_text_summary_embeddings
|
||||
else None,
|
||||
)
|
||||
|
||||
# backward pass decoder
|
||||
if self.bidirectional_decoder:
|
||||
self._init_backward_decoder()
|
||||
|
@ -166,6 +186,20 @@ class Tacotron2(BaseTacotron):
|
|||
embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1)
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers)
|
||||
|
||||
# capacitron
|
||||
if self.capacitron_vae and self.use_capacitron_vae:
|
||||
# B x capacitron_VAE_embedding_dim
|
||||
encoder_outputs, *capacitron_vae_outputs = self.compute_capacitron_VAE_embedding(
|
||||
encoder_outputs,
|
||||
reference_mel_info=[mel_specs, mel_lengths],
|
||||
text_info=[embedded_inputs.transpose(1, 2), text_lengths]
|
||||
if self.capacitron_vae.capacitron_use_text_summary_embeddings
|
||||
else None,
|
||||
speaker_embedding=embedded_speakers if self.capacitron_vae.capacitron_use_speaker_embedding else None,
|
||||
)
|
||||
else:
|
||||
capacitron_vae_outputs = None
|
||||
|
||||
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
|
||||
|
||||
# B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
|
||||
|
@ -197,6 +231,7 @@ class Tacotron2(BaseTacotron):
|
|||
"decoder_outputs": decoder_outputs,
|
||||
"alignments": alignments,
|
||||
"stop_tokens": stop_tokens,
|
||||
"capacitron_vae_outputs": capacitron_vae_outputs,
|
||||
}
|
||||
)
|
||||
return outputs
|
||||
|
@ -217,6 +252,29 @@ class Tacotron2(BaseTacotron):
|
|||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
|
||||
|
||||
if self.capacitron_vae and self.use_capacitron_vae:
|
||||
if aux_input["style_text"] is not None:
|
||||
style_text_embedding = self.embedding(aux_input["style_text"])
|
||||
style_text_length = torch.tensor([style_text_embedding.size(1)], dtype=torch.int64).to(
|
||||
encoder_outputs.device
|
||||
) # pylint: disable=not-callable
|
||||
reference_mel_length = (
|
||||
torch.tensor([aux_input["style_mel"].size(1)], dtype=torch.int64).to(encoder_outputs.device)
|
||||
if aux_input["style_mel"] is not None
|
||||
else None
|
||||
) # pylint: disable=not-callable
|
||||
# B x capacitron_VAE_embedding_dim
|
||||
encoder_outputs, *_ = self.compute_capacitron_VAE_embedding(
|
||||
encoder_outputs,
|
||||
reference_mel_info=[aux_input["style_mel"], reference_mel_length]
|
||||
if aux_input["style_mel"] is not None
|
||||
else None,
|
||||
text_info=[style_text_embedding, style_text_length] if aux_input["style_text"] is not None else None,
|
||||
speaker_embedding=aux_input["d_vectors"]
|
||||
if self.capacitron_vae.capacitron_use_speaker_embedding
|
||||
else None,
|
||||
)
|
||||
|
||||
if self.num_speakers > 1:
|
||||
if not self.use_d_vector_file:
|
||||
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[None]
|
||||
|
@ -242,6 +300,13 @@ class Tacotron2(BaseTacotron):
|
|||
}
|
||||
return outputs
|
||||
|
||||
def before_backward_pass(self, loss_dict, optimizer) -> None:
|
||||
# Extracting custom training specific operations for capacitron
|
||||
# from the trainer
|
||||
if self.use_capacitron_vae:
|
||||
loss_dict["capacitron_vae_beta_loss"].backward()
|
||||
optimizer.first_step()
|
||||
|
||||
def train_step(self, batch: Dict, criterion: torch.nn.Module):
|
||||
"""A single training step. Forward pass and loss computation.
|
||||
|
||||
|
@ -258,14 +323,8 @@ class Tacotron2(BaseTacotron):
|
|||
speaker_ids = batch["speaker_ids"]
|
||||
d_vectors = batch["d_vectors"]
|
||||
|
||||
# forward pass model
|
||||
outputs = self.forward(
|
||||
text_input,
|
||||
text_lengths,
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
aux_input={"speaker_ids": speaker_ids, "d_vectors": d_vectors},
|
||||
)
|
||||
aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)
|
||||
|
||||
# set the [alignment] lengths wrt reduction factor for guided attention
|
||||
if mel_lengths.max() % self.decoder.r != 0:
|
||||
|
@ -275,9 +334,6 @@ class Tacotron2(BaseTacotron):
|
|||
else:
|
||||
alignment_lengths = mel_lengths // self.decoder.r
|
||||
|
||||
aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)
|
||||
|
||||
# compute loss
|
||||
with autocast(enabled=False): # use float32 for the criterion
|
||||
loss_dict = criterion(
|
||||
|
@ -288,6 +344,7 @@ class Tacotron2(BaseTacotron):
|
|||
outputs["stop_tokens"].float(),
|
||||
stop_targets.float(),
|
||||
stop_target_lengths,
|
||||
outputs["capacitron_vae_outputs"] if self.capacitron_vae else None,
|
||||
mel_lengths,
|
||||
None if outputs["decoder_outputs_backward"] is None else outputs["decoder_outputs_backward"].float(),
|
||||
outputs["alignments"].float(),
|
||||
|
@ -301,6 +358,25 @@ class Tacotron2(BaseTacotron):
|
|||
loss_dict["align_error"] = align_error
|
||||
return outputs, loss_dict
|
||||
|
||||
def get_optimizer(self) -> List:
|
||||
if self.use_capacitron_vae:
|
||||
return CapacitronOptimizer(self.config, self.named_parameters())
|
||||
return get_optimizer(self.config.optimizer, self.config.optimizer_params, self.config.lr, self)
|
||||
|
||||
def get_scheduler(self, optimizer: object):
|
||||
opt = optimizer.primary_optimizer if self.use_capacitron_vae else optimizer
|
||||
return get_scheduler(self.config.lr_scheduler, self.config.lr_scheduler_params, opt)
|
||||
|
||||
def before_gradient_clipping(self):
|
||||
if self.use_capacitron_vae:
|
||||
# Capacitron model specific gradient clipping
|
||||
model_params_to_clip = []
|
||||
for name, param in self.named_parameters():
|
||||
if param.requires_grad:
|
||||
if name != "capacitron_vae_layer.beta":
|
||||
model_params_to_clip.append(param)
|
||||
torch.nn.utils.clip_grad_norm_(model_params_to_clip, self.capacitron_vae.capacitron_grad_clip)
|
||||
|
||||
def _create_logs(self, batch, outputs, ap):
|
||||
"""Create dashboard log information."""
|
||||
postnet_outputs = outputs["model_outputs"]
|
||||
|
|
|
@ -26,6 +26,7 @@ def run_model_torch(
|
|||
inputs: torch.Tensor,
|
||||
speaker_id: int = None,
|
||||
style_mel: torch.Tensor = None,
|
||||
style_text: str = None,
|
||||
d_vector: torch.Tensor = None,
|
||||
language_id: torch.Tensor = None,
|
||||
) -> Dict:
|
||||
|
@ -53,6 +54,7 @@ def run_model_torch(
|
|||
"speaker_ids": speaker_id,
|
||||
"d_vectors": d_vector,
|
||||
"style_mel": style_mel,
|
||||
"style_text": style_text,
|
||||
"language_ids": language_id,
|
||||
},
|
||||
)
|
||||
|
@ -115,6 +117,7 @@ def synthesis(
|
|||
use_cuda,
|
||||
speaker_id=None,
|
||||
style_wav=None,
|
||||
style_text=None,
|
||||
use_griffin_lim=False,
|
||||
do_trim_silence=False,
|
||||
d_vector=None,
|
||||
|
@ -140,7 +143,12 @@ def synthesis(
|
|||
Speaker ID passed to the speaker embedding layer in multi-speaker model. Defaults to None.
|
||||
|
||||
style_wav (str | Dict[str, float]):
|
||||
Path or tensor to/of a waveform used for computing the style embedding. Defaults to None.
|
||||
Path or tensor to/of a waveform used for computing the style embedding based on GST or Capacitron.
|
||||
Defaults to None, meaning that Capacitron models will sample from the prior distribution to
|
||||
generate random but realistic prosody.
|
||||
|
||||
style_text (str):
|
||||
Transcription of style_wav for Capacitron models. Defaults to None.
|
||||
|
||||
enable_eos_bos_chars (bool):
|
||||
enable special chars for end of sentence and start of sentence. Defaults to False.
|
||||
|
@ -154,13 +162,19 @@ def synthesis(
|
|||
language_id (int):
|
||||
Language ID passed to the language embedding layer in a multilingual model. Defaults to None.
|
||||
"""
|
||||
# GST processing
|
||||
# GST or Capacitron processing
|
||||
# TODO: need to handle the case of setting both gst and capacitron to true somewhere
|
||||
style_mel = None
|
||||
if CONFIG.has("gst") and CONFIG.gst and style_wav is not None:
|
||||
if isinstance(style_wav, dict):
|
||||
style_mel = style_wav
|
||||
else:
|
||||
style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
|
||||
|
||||
if CONFIG.has("capacitron_vae") and CONFIG.use_capacitron_vae and style_wav is not None:
|
||||
style_mel = compute_style_mel(style_wav, model.ap, cuda=use_cuda)
|
||||
style_mel = style_mel.transpose(1, 2) # [1, time, depth]
|
||||
|
||||
# convert text to sequence of token IDs
|
||||
text_inputs = np.asarray(
|
||||
model.tokenizer.text_to_ids(text, language=language_id),
|
||||
|
@ -177,11 +191,28 @@ def synthesis(
|
|||
language_id = id_to_torch(language_id, cuda=use_cuda)
|
||||
|
||||
if not isinstance(style_mel, dict):
|
||||
# GST or Capacitron style mel
|
||||
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
|
||||
if style_text is not None:
|
||||
style_text = np.asarray(
|
||||
model.tokenizer.text_to_ids(style_text, language=language_id),
|
||||
dtype=np.int32,
|
||||
)
|
||||
style_text = numpy_to_torch(style_text, torch.long, cuda=use_cuda)
|
||||
style_text = style_text.unsqueeze(0)
|
||||
|
||||
text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
|
||||
text_inputs = text_inputs.unsqueeze(0)
|
||||
# synthesize voice
|
||||
outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector, language_id=language_id)
|
||||
outputs = run_model_torch(
|
||||
model,
|
||||
text_inputs,
|
||||
speaker_id,
|
||||
style_mel,
|
||||
style_text,
|
||||
d_vector=d_vector,
|
||||
language_id=language_id,
|
||||
)
|
||||
model_outputs = outputs["model_outputs"]
|
||||
model_outputs = model_outputs[0].data.cpu().numpy()
|
||||
alignments = outputs["alignments"]
|
||||
|
|
|
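A hedged usage sketch of the updated synthesis() helper (assuming a loaded Capacitron `model` and its `config`); when style_wav is omitted, a Capacitron model samples its prosody embedding from the prior instead:

    from TTS.tts.utils.synthesis import synthesis

    outputs = synthesis(
        model,
        "A sentence spoken with the reference prosody.",
        config,
        use_cuda=False,
        style_wav="reference.wav",
        style_text="Transcription of reference.wav.",
        use_griffin_lim=True,
    )
    wav = outputs["wav"]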
@@ -0,0 +1,65 @@
from typing import Generator

from trainer.trainer_utils import get_optimizer


class CapacitronOptimizer:
    """Double optimizer class for the Capacitron model."""

    def __init__(self, config: dict, model_params: Generator) -> None:
        self.primary_params, self.secondary_params = self.split_model_parameters(model_params)

        optimizer_names = list(config.optimizer_params.keys())
        optimizer_parameters = list(config.optimizer_params.values())

        self.primary_optimizer = get_optimizer(
            optimizer_names[0],
            optimizer_parameters[0],
            config.lr,
            parameters=self.primary_params,
        )

        self.secondary_optimizer = get_optimizer(
            optimizer_names[1],
            self.extract_optimizer_parameters(optimizer_parameters[1]),
            optimizer_parameters[1]["lr"],
            parameters=self.secondary_params,
        )

        self.param_groups = self.primary_optimizer.param_groups

    def first_step(self):
        self.secondary_optimizer.step()
        self.secondary_optimizer.zero_grad()
        self.primary_optimizer.zero_grad()

    def step(self):
        self.primary_optimizer.step()

    def zero_grad(self):
        self.primary_optimizer.zero_grad()
        self.secondary_optimizer.zero_grad()

    def load_state_dict(self, state_dict):
        self.primary_optimizer.load_state_dict(state_dict[0])
        self.secondary_optimizer.load_state_dict(state_dict[1])

    def state_dict(self):
        return [self.primary_optimizer.state_dict(), self.secondary_optimizer.state_dict()]

    @staticmethod
    def split_model_parameters(model_params: Generator) -> list:
        primary_params = []
        secondary_params = []
        for name, param in model_params:
            if param.requires_grad:
                if name == "capacitron_vae_layer.beta":
                    secondary_params.append(param)
                else:
                    primary_params.append(param)
        return [iter(primary_params), iter(secondary_params)]

    @staticmethod
    def extract_optimizer_parameters(params: dict) -> dict:
        """Extract parameters that are not the learning rate"""
        return {k: v for k, v in params.items() if k != "lr"}
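Taken together with the before_backward_pass / before_gradient_clipping hooks defined on the Tacotron models above, one Capacitron update is intended to look roughly like the following. This is a hedged sketch of what the Trainer drives internally; `model` is assumed to be a Tacotron/Tacotron2 with use_capacitron_vae=True, `loss_dict` the output of TacotronLoss, and the total loss is assumed to live under the "loss" key:

    optimizer = model.get_optimizer()                  # CapacitronOptimizer (RAdam/SGD pair in the recipes)
    loss_dict["capacitron_vae_beta_loss"].backward()   # gradient for beta only
    optimizer.first_step()                             # SGD step on beta, then zero both optimizers
    loss_dict["loss"].backward()                       # standard Tacotron + capacitron_vae_loss terms
    model.before_gradient_clipping()                   # clips everything except capacitron_vae_layer.beta
    optimizer.step()                                   # primary optimizer step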
@@ -106,6 +106,8 @@ def save_model(config, model, optimizer, scaler, current_step, epoch, output_pat
    model_state = model.state_dict()
    if isinstance(optimizer, list):
        optimizer_state = [optim.state_dict() for optim in optimizer]
    elif optimizer.__class__.__name__ == "CapacitronOptimizer":
        optimizer_state = [optimizer.primary_optimizer.state_dict(), optimizer.secondary_optimizer.state_dict()]
    else:
        optimizer_state = optimizer.state_dict() if optimizer is not None else None
@@ -1,5 +1,5 @@
import time
from typing import List, Union
from typing import List

import numpy as np
import pysbd
@@ -178,8 +178,9 @@ class Synthesizer(object):
        text: str = "",
        speaker_name: str = "",
        language_name: str = "",
        speaker_wav: Union[str, List[str]] = None,
        speaker_wav=None,
        style_wav=None,
        style_text=None,
        reference_wav=None,
        reference_speaker_name=None,
    ) -> List[int]:
@@ -191,6 +192,7 @@ class Synthesizer(object):
            language_name (str, optional): language id for multi-language models. Defaults to "".
            speaker_wav (Union[str, List[str]], optional): path to the speaker wav. Defaults to None.
            style_wav ([type], optional): style waveform for GST. Defaults to None.
            style_text ([type], optional): transcription of style_wav for Capacitron. Defaults to None.
            reference_wav ([type], optional): reference waveform for voice conversion. Defaults to None.
            reference_speaker_name ([type], optional): speaker id of reference waveform. Defaults to None.
        Returns:
@@ -273,10 +275,11 @@ class Synthesizer(object):
                        CONFIG=self.tts_config,
                        use_cuda=self.use_cuda,
                        speaker_id=speaker_id,
                        language_id=language_id,
                        style_wav=style_wav,
                        style_text=style_text,
                        use_griffin_lim=use_gl,
                        d_vector=speaker_embedding,
                        language_id=language_id,
                    )
                waveform = outputs["wav"]
                mel_postnet_spec = outputs["outputs"]["model_outputs"][0].detach().cpu().numpy()
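At the user-facing level, the new style_text argument flows from Synthesizer.tts into synthesis(). A hedged sketch (checkpoint and reference paths are placeholders):

    from TTS.utils.synthesizer import Synthesizer

    synthesizer = Synthesizer(tts_checkpoint="model.pth", tts_config_path="config.json", use_cuda=False)
    wav = synthesizer.tts(
        text="This sentence should follow the reference prosody.",
        style_wav="reference.wav",
        style_text="Transcription of reference.wav.",
    )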
@@ -0,0 +1,12 @@
# How to get the Blizzard 2013 Dataset

The Capacitron model is a variational-encoder extension of standard Tacotron-based models for modelling prosody.

To take full advantage of the model, it is advised to train it on a dataset that contains a significant amount of prosodic variation in the utterances. A tested candidate for such applications is the Blizzard 2013 dataset from the Blizzard Challenge, which contains many hours of high-quality audiobook recordings.

To get a license and a download link for this dataset, you need to visit the [website](https://www.cstr.ed.ac.uk/projects/blizzard/2013/lessac_blizzard2013/license.html) of the Centre for Speech Technology Research of the University of Edinburgh.

You will typically get access to the raw dataset within a couple of days. A few preprocessing steps are then needed before the high-fidelity dataset can be used:

1. Get the forced time alignments for the Blizzard dataset from [here](https://github.com/mueller91/tts_alignments).
2. Segment the high-fidelity audiobook files following the instructions [here](https://github.com/Tomiinek/Blizzard2013_Segmentation); the recipes in this commit then load the segmented data with the LJSpeech formatter, as sketched below.
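A hedged loading sketch that mirrors the recipes added in this commit (the segmented corpus is assumed to follow the LJSpeech layout, i.e. a `metadata.csv` next to the audio files; the path is a placeholder):

    from TTS.tts.configs.shared_configs import BaseDatasetConfig
    from TTS.tts.datasets import load_tts_samples

    dataset_config = BaseDatasetConfig(
        name="ljspeech", meta_file_train="metadata.csv", path="/srv/data/blizzard2013/segmented"
    )
    train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)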
@ -0,0 +1,101 @@
import os

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig, CapacitronVAEConfig
from TTS.tts.configs.tacotron_config import TacotronConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.tacotron import Tacotron
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

output_path = os.path.dirname(os.path.abspath(__file__))

data_path = "/srv/data/"

# Using LJSpeech like dataset processing for the blizzard dataset
dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path=data_path)

audio_config = BaseAudioConfig(
    sample_rate=24000,
    do_trim_silence=True,
    trim_db=60.0,
    signal_norm=True,
    mel_fmin=80.0,
    mel_fmax=12000,
    spec_gain=20.0,
    log_func="np.log10",
    ref_level_db=20,
    preemphasis=0.0,
    min_level_db=-100,
)

# Using the standard Capacitron config
capacitron_config = CapacitronVAEConfig(capacitron_VAE_loss_alpha=1.0)

config = TacotronConfig(
    run_name="Blizzard-Capacitron-T1",
    audio=audio_config,
    capacitron_vae=capacitron_config,
    use_capacitron_vae=True,
    batch_size=128,  # Tune this to your gpu
    max_audio_len=6 * 24000,  # Tune this to your gpu
    min_audio_len=0.5 * 24000,
    eval_batch_size=16,
    num_loader_workers=12,
    num_eval_loader_workers=8,
    precompute_num_workers=24,
    run_eval=True,
    test_delay_epochs=5,
    ga_alpha=0.0,
    r=2,
    optimizer="CapacitronOptimizer",
    optimizer_params={"RAdam": {"betas": [0.9, 0.998], "weight_decay": 1e-6}, "SGD": {"lr": 1e-5, "momentum": 0.9}},
    attention_type="graves",
    attention_heads=5,
    epochs=1000,
    text_cleaner="phoneme_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phonemizer="espeak",
    phoneme_cache_path=os.path.join(data_path, "phoneme_cache"),
    stopnet_pos_weight=15,
    print_step=50,
    print_eval=True,
    mixed_precision=False,
    output_path=output_path,
    datasets=[dataset_config],
    lr=1e-3,
    lr_scheduler="StepwiseGradualLR",
    lr_scheduler_params={"gradual_learning_rates": [[0, 1e-3], [2e4, 5e-4], [4e5, 3e-4], [6e4, 1e-4], [8e4, 5e-5]]},
    scheduler_after_epoch=False,  # scheduler doesn't work without this flag
    # Need to experiment with these below for capacitron
    loss_masking=False,
    decoder_loss_alpha=1.0,
    postnet_loss_alpha=1.0,
    postnet_diff_spec_alpha=0.0,
    decoder_diff_spec_alpha=0.0,
    decoder_ssim_alpha=0.0,
    postnet_ssim_alpha=0.0,
)

ap = AudioProcessor(**config.audio.to_dict())

tokenizer, config = TTSTokenizer.init_from_config(config)

train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)

model = Tacotron(config, ap, tokenizer, speaker_manager=None)

trainer = Trainer(
    TrainerArgs(),
    config,
    output_path,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
)

# 🚀
trainer.fit()

@ -0,0 +1,117 @@
import os

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig, CapacitronVAEConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.tacotron2 import Tacotron2
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

output_path = os.path.dirname(os.path.abspath(__file__))

data_path = "/srv/data/blizzard2013/segmented"

# Using LJSpeech like dataset processing for the blizzard dataset
dataset_config = BaseDatasetConfig(
    name="ljspeech",
    meta_file_train="metadata.csv",
    path=data_path,
)

audio_config = BaseAudioConfig(
    sample_rate=24000,
    do_trim_silence=True,
    trim_db=60.0,
    signal_norm=True,
    mel_fmin=80.0,
    mel_fmax=12000,
    spec_gain=25.0,
    log_func="np.log10",
    ref_level_db=20,
    preemphasis=0.0,
    min_level_db=-100,
)

# Using the standard Capacitron config
capacitron_config = CapacitronVAEConfig(capacitron_VAE_loss_alpha=1.0)

config = Tacotron2Config(
    run_name="Blizzard-Capacitron-T2",
    audio=audio_config,
    capacitron_vae=capacitron_config,
    use_capacitron_vae=True,
    batch_size=246,  # Tune this to your gpu
    max_audio_len=6 * 24000,  # Tune this to your gpu
    min_audio_len=1 * 24000,
    eval_batch_size=16,
    num_loader_workers=12,
    num_eval_loader_workers=8,
    precompute_num_workers=24,
    run_eval=True,
    test_delay_epochs=5,
    ga_alpha=0.0,
    r=2,
    optimizer="CapacitronOptimizer",
    optimizer_params={"RAdam": {"betas": [0.9, 0.998], "weight_decay": 1e-6}, "SGD": {"lr": 1e-5, "momentum": 0.9}},
    attention_type="dynamic_convolution",
    grad_clip=0.0,  # Important! We overwrite the standard grad_clip with capacitron_grad_clip
    double_decoder_consistency=False,
    epochs=1000,
    text_cleaner="phoneme_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phonemizer="espeak",
    phoneme_cache_path=os.path.join(data_path, "phoneme_cache"),
    stopnet_pos_weight=15,
    print_step=25,
    print_eval=True,
    mixed_precision=False,
    output_path=output_path,
    datasets=[dataset_config],
    lr=1e-3,
    lr_scheduler="StepwiseGradualLR",
    lr_scheduler_params={
        "gradual_learning_rates": [
            [0, 1e-3],
            [2e4, 5e-4],
            [4e5, 3e-4],
            [6e4, 1e-4],
            [8e4, 5e-5],
        ]
    },
    scheduler_after_epoch=False,  # scheduler doesn't work without this flag
    # dashboard_logger='wandb',
    # sort_by_audio_len=True,
    seq_len_norm=True,
    # Need to experiment with these below for capacitron
    loss_masking=False,
    decoder_loss_alpha=1.0,
    postnet_loss_alpha=1.0,
    postnet_diff_spec_alpha=0.0,
    decoder_diff_spec_alpha=0.0,
    decoder_ssim_alpha=0.0,
    postnet_ssim_alpha=0.0,
)

ap = AudioProcessor(**config.audio.to_dict())

tokenizer, config = TTSTokenizer.init_from_config(config)

train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)

model = Tacotron2(config, ap, tokenizer, speaker_manager=None)

trainer = Trainer(
    TrainerArgs(),
    config,
    output_path,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
    training_assets={"audio_processor": ap},
)

trainer.fit()

@ -0,0 +1,115 @@
import os

from trainer import Trainer, TrainerArgs

from TTS.config.shared_configs import BaseAudioConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig, CapacitronVAEConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.tacotron2 import Tacotron2
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

output_path = os.path.dirname(os.path.abspath(__file__))

data_path = "/srv/data/"

# Using LJSpeech like dataset processing for the blizzard dataset
dataset_config = BaseDatasetConfig(
    name="ljspeech",
    meta_file_train="metadata.csv",
    path=data_path,
)

audio_config = BaseAudioConfig(
    sample_rate=22050,
    do_trim_silence=True,
    trim_db=60.0,
    signal_norm=False,
    mel_fmin=0.0,
    mel_fmax=11025,
    spec_gain=1.0,
    log_func="np.log",
    ref_level_db=20,
    preemphasis=0.0,
)

# Using the standard Capacitron config
capacitron_config = CapacitronVAEConfig(capacitron_VAE_loss_alpha=1.0, capacitron_capacity=50)

config = Tacotron2Config(
    run_name="Capacitron-Tacotron2",
    audio=audio_config,
    capacitron_vae=capacitron_config,
    use_capacitron_vae=True,
    batch_size=128,  # Tune this to your gpu
    max_audio_len=8 * 22050,  # Tune this to your gpu
    min_audio_len=1 * 22050,
    eval_batch_size=16,
    num_loader_workers=8,
    num_eval_loader_workers=8,
    precompute_num_workers=24,
    run_eval=True,
    test_delay_epochs=25,
    ga_alpha=0.0,
    r=2,
    optimizer="CapacitronOptimizer",
    optimizer_params={"RAdam": {"betas": [0.9, 0.998], "weight_decay": 1e-6}, "SGD": {"lr": 1e-5, "momentum": 0.9}},
    attention_type="dynamic_convolution",
    grad_clip=0.0,  # Important! We overwrite the standard grad_clip with capacitron_grad_clip
    double_decoder_consistency=False,
    epochs=1000,
    text_cleaner="phoneme_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phonemizer="espeak",
    phoneme_cache_path=os.path.join(data_path, "phoneme_cache"),
    stopnet_pos_weight=15,
    print_step=25,
    print_eval=True,
    mixed_precision=False,
    sort_by_audio_len=True,
    seq_len_norm=True,
    output_path=output_path,
    datasets=[dataset_config],
    lr=1e-3,
    lr_scheduler="StepwiseGradualLR",
    lr_scheduler_params={
        "gradual_learning_rates": [
            [0, 1e-3],
            [2e4, 5e-4],
            [4e5, 3e-4],
            [6e4, 1e-4],
            [8e4, 5e-5],
        ]
    },
    scheduler_after_epoch=False,  # scheduler doesn't work without this flag
    # Need to experiment with these below for capacitron
    loss_masking=False,
    decoder_loss_alpha=1.0,
    postnet_loss_alpha=1.0,
    postnet_diff_spec_alpha=0.0,
    decoder_diff_spec_alpha=0.0,
    decoder_ssim_alpha=0.0,
    postnet_ssim_alpha=0.0,
)

ap = AudioProcessor(**config.audio.to_dict())

tokenizer, config = TTSTokenizer.init_from_config(config)

train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)

model = Tacotron2(config, ap, tokenizer, speaker_manager=None)

trainer = Trainer(
    TrainerArgs(),
    config,
    output_path,
    model=model,
    train_samples=train_samples,
    eval_samples=eval_samples,
    training_assets={"audio_processor": ap},
)

trainer.fit()

@ -6,7 +6,7 @@ import torch
from torch import nn, optim

from tests import get_tests_input_path
from TTS.tts.configs.shared_configs import GSTConfig
from TTS.tts.configs.shared_configs import CapacitronVAEConfig, GSTConfig
from TTS.tts.configs.tacotron2_config import Tacotron2Config
from TTS.tts.layers.losses import MSELossMasked
from TTS.tts.models.tacotron2 import Tacotron2

@ -260,6 +260,73 @@ class TacotronGSTTrainTest(unittest.TestCase):
            count += 1


class TacotronCapacitronTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        config = Tacotron2Config(
            num_chars=32,
            num_speakers=10,
            use_speaker_embedding=True,
            out_channels=80,
            decoder_output_dim=80,
            use_capacitron_vae=True,
            capacitron_vae=CapacitronVAEConfig(),
            optimizer="CapacitronOptimizer",
            optimizer_params={
                "RAdam": {"betas": [0.9, 0.998], "weight_decay": 1e-6},
                "SGD": {"lr": 1e-5, "momentum": 0.9},
            },
        )

        batch = dict({})
        batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device)
        batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device)
        batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0]
        batch["text_lengths"][0] = 128
        batch["mel_input"] = torch.rand(8, 120, config.audio["num_mels"]).to(device)
        batch["mel_lengths"] = torch.randint(20, 120, (8,)).long().to(device)
        batch["mel_lengths"] = torch.sort(batch["mel_lengths"], descending=True)[0]
        batch["mel_lengths"][0] = 120
        batch["stop_targets"] = torch.zeros(8, 120, 1).float().to(device)
        batch["stop_target_lengths"] = torch.randint(0, 120, (8,)).to(device)
        batch["speaker_ids"] = torch.randint(0, 5, (8,)).long().to(device)
        batch["d_vectors"] = None

        for idx in batch["mel_lengths"]:
            batch["stop_targets"][:, int(idx.item()) :, 0] = 1.0

        batch["stop_targets"] = batch["stop_targets"].view(
            batch["text_input"].shape[0], batch["stop_targets"].size(1) // config.r, -1
        )
        batch["stop_targets"] = (batch["stop_targets"].sum(2) > 0.0).unsqueeze(2).float().squeeze()

        model = Tacotron2(config).to(device)
        criterion = model.get_criterion()
        optimizer = model.get_optimizer()

        model.train()
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        for _ in range(10):
            _, loss_dict = model.train_step(batch, criterion)
            optimizer.zero_grad()
            loss_dict["capacitron_vae_beta_loss"].backward()
            optimizer.first_step()
            loss_dict["loss"].backward()
            optimizer.step()
        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            # ignore pre-highway layer since it works conditionally
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref
            )
            count += 1


class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
    """Test multi-speaker Tacotron2 with Global Style Tokens and d-vector inputs."""

@ -6,7 +6,7 @@ import torch
from torch import nn, optim

from tests import get_tests_input_path
from TTS.tts.configs.shared_configs import GSTConfig
from TTS.tts.configs.shared_configs import CapacitronVAEConfig, GSTConfig
from TTS.tts.configs.tacotron_config import TacotronConfig
from TTS.tts.layers.losses import L1LossMasked
from TTS.tts.models.tacotron import Tacotron

@ -248,6 +248,74 @@ class TacotronGSTTrainTest(unittest.TestCase):
            count += 1


class TacotronCapacitronTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
        config = TacotronConfig(
            num_chars=32,
            num_speakers=10,
            use_speaker_embedding=True,
            out_channels=513,
            decoder_output_dim=80,
            use_capacitron_vae=True,
            capacitron_vae=CapacitronVAEConfig(),
            optimizer="CapacitronOptimizer",
            optimizer_params={
                "RAdam": {"betas": [0.9, 0.998], "weight_decay": 1e-6},
                "SGD": {"lr": 1e-5, "momentum": 0.9},
            },
        )

        batch = dict({})
        batch["text_input"] = torch.randint(0, 24, (8, 128)).long().to(device)
        batch["text_lengths"] = torch.randint(100, 129, (8,)).long().to(device)
        batch["text_lengths"] = torch.sort(batch["text_lengths"], descending=True)[0]
        batch["text_lengths"][0] = 128
        batch["linear_input"] = torch.rand(8, 120, config.audio["fft_size"] // 2 + 1).to(device)
        batch["mel_input"] = torch.rand(8, 120, config.audio["num_mels"]).to(device)
        batch["mel_lengths"] = torch.randint(20, 120, (8,)).long().to(device)
        batch["mel_lengths"] = torch.sort(batch["mel_lengths"], descending=True)[0]
        batch["mel_lengths"][0] = 120
        batch["stop_targets"] = torch.zeros(8, 120, 1).float().to(device)
        batch["stop_target_lengths"] = torch.randint(0, 120, (8,)).to(device)
        batch["speaker_ids"] = torch.randint(0, 5, (8,)).long().to(device)
        batch["d_vectors"] = None

        for idx in batch["mel_lengths"]:
            batch["stop_targets"][:, int(idx.item()) :, 0] = 1.0

        batch["stop_targets"] = batch["stop_targets"].view(
            batch["text_input"].shape[0], batch["stop_targets"].size(1) // config.r, -1
        )
        batch["stop_targets"] = (batch["stop_targets"].sum(2) > 0.0).unsqueeze(2).float().squeeze()

        model = Tacotron(config).to(device)
        criterion = model.get_criterion()
        optimizer = model.get_optimizer()
        model.train()
        print(" > Num parameters for Tacotron with Capacitron VAE model:%s" % (count_parameters(model)))
        model_ref = copy.deepcopy(model)
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            assert (param - param_ref).sum() == 0, param
            count += 1
        for _ in range(10):
            _, loss_dict = model.train_step(batch, criterion)
            optimizer.zero_grad()
            loss_dict["capacitron_vae_beta_loss"].backward()
            optimizer.first_step()
            loss_dict["loss"].backward()
            optimizer.step()
        # check parameter changes
        count = 0
        for param, param_ref in zip(model.parameters(), model_ref.parameters()):
            # ignore pre-highway layer since it works conditionally
            assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
                count, param.shape, param, param_ref
            )
            count += 1


class SCGSTMultiSpeakeTacotronTrainTest(unittest.TestCase):
    @staticmethod
    def test_train_step():
