Merge branch 'dev' of https://github.com/mozilla/TTS into dev

erogol 2020-06-05 13:28:39 +02:00
commit d00b91710a
13 changed files with 399 additions and 165 deletions


@ -11,15 +11,18 @@ If you are new, you can also find [here](http://www.erogol.com/text-speech-deep-
[![](https://sourcerer.io/fame/erogol/erogol/TTS/images/0)](https://sourcerer.io/fame/erogol/erogol/TTS/links/0)[![](https://sourcerer.io/fame/erogol/erogol/TTS/images/1)](https://sourcerer.io/fame/erogol/erogol/TTS/links/1)[![](https://sourcerer.io/fame/erogol/erogol/TTS/images/2)](https://sourcerer.io/fame/erogol/erogol/TTS/links/2)[![](https://sourcerer.io/fame/erogol/erogol/TTS/images/3)](https://sourcerer.io/fame/erogol/erogol/TTS/links/3)[![](https://sourcerer.io/fame/erogol/erogol/TTS/images/4)](https://sourcerer.io/fame/erogol/erogol/TTS/links/4)[![](https://sourcerer.io/fame/erogol/erogol/TTS/images/5)](https://sourcerer.io/fame/erogol/erogol/TTS/links/5)[![](https://sourcerer.io/fame/erogol/erogol/TTS/images/6)](https://sourcerer.io/fame/erogol/erogol/TTS/links/6)[![](https://sourcerer.io/fame/erogol/erogol/TTS/images/7)](https://sourcerer.io/fame/erogol/erogol/TTS/links/7)
## TTS Performance
<p align="center"><img src="https://camo.githubusercontent.com/9fa79f977015e55eb9ec7aa32045555f60d093d3/68747470733a2f2f646973636f757273652d706161732d70726f64756374696f6e2d636f6e74656e742e73332e6475616c737461636b2e75732d656173742d312e616d617a6f6e6177732e636f6d2f6f7074696d697a65642f33582f362f342f363432386639383065396563373531633234386535393134363038393566373838316165633063365f325f363930783339342e706e67"/></p>
[Details...](https://github.com/mozilla/TTS/wiki/Mean-Opinion-Score-Results)
## Features
- High performance Deep Learning models for Text2Speech related tasks.
- Text2Speech models (Tacotron, Tacotron2).
- Speaker Encoder to compute speaker embeddings efficiently.
- Vocoder models (MelGAN, Multiband-MelGAN, GAN-TTS)
- Support for multi-speaker TTS training.
- Ability to convert Torch models to Tensorflow 2.0 for inference.
- Released trained models.
- Efficient training codes for PyTorch. (soon for Tensorflow 2.0)
- Codes to convert Torch models to Tensorflow 2.0.
@ -84,7 +87,7 @@ Audio length is approximately 6 secs.
## Datasets and Data-Loading
TTS provides a generic data loader that is easy to use for new datasets. You need to write a preprocessor function to integrate your own dataset. Check ```datasets/preprocess.py``` to see some examples. After writing the function, set the ```dataset``` field in ```config.json```. Do not forget the other data-related fields too.
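Below is a minimal sketch of what such a preprocessor could look like. The function name, metadata layout and returned item format here are assumptions for illustration only; mirror one of the functions in ```datasets/preprocess.py``` for your own data.
```
import os

def my_dataset(root_path, meta_file):
    """Hypothetical preprocessor: return [text, wav_path] items for the generic loader.

    Assumes a metadata file with "<wav_name>|<transcript>" lines and audio files
    under <root_path>/wavs/. Check datasets/preprocess.py for the exact item
    format the loader expects.
    """
    items = []
    with open(os.path.join(root_path, meta_file), 'r', encoding='utf-8') as f:
        for line in f:
            wav_name, text = line.strip().split('|')
            wav_path = os.path.join(root_path, 'wavs', wav_name + '.wav')
            items.append([text, wav_path])
    return items
```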
Some of the open-source datasets to which we have successfully applied TTS are linked below.
@ -96,9 +99,9 @@ Some of the open-sourced datasets that we successfully applied TTS, are linked b
- [Spanish](https://drive.google.com/file/d/1Sm_zyBo67XHkiFhcRSQ4YaHPYM0slO_e/view?usp=sharing) - thx! @carlfm01
## Training and Fine-tuning LJ-Speech
Here you can find a [CoLab](https://gist.github.com/erogol/97516ad65b44dbddb8cd694953187c5b) notebook for a hands-on example, training LJSpeech. Or you can manually follow the guideline below.
To start with, split ```metadata.csv``` into train and validation subsets respectively ```metadata_train.csv``` and ```metadata_val.csv```. Note that for text-to-speech, validation performance might be misleading since the loss value does not directly measure the voice quality to the human ear and it also does not measure the attention module performance. Therefore, running the model with new sentences and listening to the results is the best way to go.
```
shuf metadata.csv > metadata_shuf.csv
@ -137,10 +140,10 @@ cardboardlinter --refspec master
```
## Collaborative Experimentation Guide
If you would like to use TTS to try a new idea and share your experiments with the community, we urge you to follow the guideline below for better collaboration.
(If you have an idea for better collaboration, let us know.)
- Create a new branch.
- Open an issue pointing to your branch.
- Explain your experiment.
- Share your results as you proceed. (Tensorboard log files, audio results, visuals etc.)
- Use the LJSpeech dataset (for English) if you would like to compare your results with the released models. (It is the most open and scalable dataset for quick experimentation.)


@ -96,6 +96,8 @@
"transition_agent": false, // enable/disable transition agent of forward attention. "transition_agent": false, // enable/disable transition agent of forward attention.
"location_attn": true, // enable_disable location sensitive attention. It is enabled for TACOTRON by default. "location_attn": true, // enable_disable location sensitive attention. It is enabled for TACOTRON by default.
"bidirectional_decoder": false, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset. "bidirectional_decoder": false, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
"double_decoder_consistency": true, // use DDC explained here https://erogol.com/solving-attention-problems-of-tts-models-with-double-decoder-consistency-draft/
"ddc_r": 7, // reduction rate for coarse decoder.
// STOPNET
"stopnet": true, // Train stopnet predicting the end of synthesis.


@ -184,7 +184,7 @@ class TacotronLoss(torch.nn.Module):
def forward(self, postnet_output, decoder_output, mel_input, linear_input,
stopnet_output, stopnet_target, output_lens, decoder_b_output,
alignments, alignment_lens, alignments_backwards, input_lens):
return_dict = {}
# decoder and postnet losses
@ -226,6 +226,15 @@ class TacotronLoss(torch.nn.Module):
return_dict['decoder_b_loss'] = decoder_b_loss
return_dict['decoder_c_loss'] = decoder_c_loss
# double decoder consistency loss (if enabled)
if self.config.double_decoder_consistency:
decoder_b_loss = self.criterion(decoder_b_output, mel_input, output_lens)
# decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output)
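# attention consistency (DDC): L1 distance between the main decoder alignments and the coarse decoder alignments (interpolated to the same length in _coarse_decoder_pass)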
attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards)
loss += decoder_b_loss + attention_c_loss
return_dict['decoder_coarse_loss'] = decoder_b_loss
return_dict['decoder_ddc_loss'] = attention_c_loss
# guided attention loss (if enabled)
if self.config.ga_alpha > 0:
ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens)


@ -1,23 +1,21 @@
# coding: utf-8
import torch
from torch import nn
from TTS.layers.gst_layers import GST
from TTS.layers.tacotron import Decoder, Encoder, PostCBHG
from TTS.models.tacotron_abstract import TacotronAbstract
class Tacotron(TacotronAbstract):
def __init__(self,
num_chars,
num_speakers,
r=5,
postnet_output_dim=1025,
decoder_output_dim=80,
attn_type='original',
attn_win=False,
attn_norm="sigmoid",
prenet_type="original",
prenet_dropout=True,
@ -27,38 +25,41 @@ class Tacotron(nn.Module):
location_attn=True,
attn_K=5,
separate_stopnet=True,
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
gst=False,
memory_size=5):
super(Tacotron,
self).__init__(num_chars, num_speakers, r, postnet_output_dim,
decoder_output_dim, attn_type, attn_win,
attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, forward_attn_mask,
location_attn, attn_K, separate_stopnet,
bidirectional_decoder, double_decoder_consistency,
ddc_r, gst)
decoder_in_features = 512 if num_speakers > 1 else 256
encoder_in_features = 512 if num_speakers > 1 else 256
speaker_embedding_dim = 256
proj_speaker_dim = 80 if num_speakers > 1 else 0
# base model layers
self.embedding = nn.Embedding(num_chars, 256, padding_idx=0)
self.embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(encoder_in_features)
self.decoder = Decoder(decoder_in_features, decoder_output_dim, r, memory_size, attn_type, attn_win,
attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, forward_attn_mask,
location_attn, attn_K, separate_stopnet,
proj_speaker_dim)
self.postnet = PostCBHG(decoder_output_dim)
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2,
postnet_output_dim)
# speaker embedding layers
if num_speakers > 1:
self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.speaker_project_mel = nn.Sequential(
nn.Linear(speaker_embedding_dim, proj_speaker_dim), nn.Tanh())
self.speaker_embeddings = None
self.speaker_embeddings_projected = None
# global style token layers
@ -68,28 +69,15 @@ class Tacotron(nn.Module):
num_heads=4,
num_style_tokens=10,
embedding_dim=gst_embedding_dim)
# backward pass decoder
if self.bidirectional_decoder:
self._init_backward_decoder()
# setup DDC
if self.double_decoder_consistency:
self._init_coarse_decoder()
def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None):
"""
Shapes:
- characters: B x T_in
@ -98,45 +86,59 @@ class Tacotron(nn.Module):
- speaker_ids: B x 1
"""
self._init_states()
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
# B x T_in x embed_dim
inputs = self.embedding(characters)
# B x speaker_embed_dim
if speaker_ids is not None:
self.compute_speaker_embedding(speaker_ids)
if self.num_speakers > 1:
# B x T_in x embed_dim + speaker_embed_dim
inputs = self._concat_speaker_embedding(inputs,
self.speaker_embeddings)
# B x T_in x encoder_in_features
encoder_outputs = self.encoder(inputs)
# sequence masking
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
# global style token
if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
if self.num_speakers > 1:
encoder_outputs = self._concat_speaker_embedding(
encoder_outputs, self.speaker_embeddings)
# decoder_outputs: B x decoder_in_features x T_out
# alignments: B x T_in x encoder_in_features
# stop_tokens: B x T_in
decoder_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs, input_mask,
self.speaker_embeddings_projected)
# sequence masking
if output_mask is not None:
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
# B x T_out x decoder_in_features
postnet_outputs = self.postnet(decoder_outputs)
# sequence masking
if output_mask is not None:
postnet_outputs = postnet_outputs * output_mask.unsqueeze(2).expand_as(postnet_outputs)
# B x T_out x posnet_dim
postnet_outputs = self.last_linear(postnet_outputs)
# B x T_out x decoder_in_features
decoder_outputs = decoder_outputs.transpose(1, 2).contiguous()
if self.bidirectional_decoder:
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
if self.double_decoder_consistency:
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(mel_specs, encoder_outputs, alignments, input_mask)
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
return decoder_outputs, postnet_outputs, alignments, stop_tokens
@torch.no_grad()
def inference(self, characters, speaker_ids=None, style_mel=None):
inputs = self.embedding(characters)
self._init_states()
if speaker_ids is not None:
self.compute_speaker_embedding(speaker_ids)
if self.num_speakers > 1:
inputs = self._concat_speaker_embedding(inputs,
self.speaker_embeddings)
@ -152,28 +154,3 @@ class Tacotron(nn.Module):
postnet_outputs = self.last_linear(postnet_outputs)
decoder_outputs = decoder_outputs.transpose(1, 2)
return decoder_outputs, postnet_outputs, alignments, stop_tokens


@ -1,13 +1,15 @@
from math import sqrt
import torch
from torch import nn
from TTS.layers.gst_layers import GST
from TTS.layers.tacotron2 import Decoder, Encoder, Postnet
from TTS.models.tacotron_abstract import TacotronAbstract
# TODO: match function arguments with tacotron
class Tacotron2(TacotronAbstract):
def __init__(self,
num_chars,
num_speakers,
@ -25,16 +27,22 @@ class Tacotron2(nn.Module):
location_attn=True,
attn_K=5,
separate_stopnet=True,
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
gst=False):
super(Tacotron2,
self).__init__(num_chars, num_speakers, r, postnet_output_dim,
decoder_output_dim, attn_type, attn_win,
attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, forward_attn_mask,
location_attn, attn_K, separate_stopnet,
bidirectional_decoder, double_decoder_consistency,
ddc_r, gst)
decoder_in_features = 512 if num_speakers > 1 else 512
encoder_in_features = 512 if num_speakers > 1 else 512
proj_speaker_dim = 80 if num_speakers > 1 else 0
# base layers
self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
std = sqrt(2.0 / (num_chars + 512))
val = sqrt(3.0) * std # uniform bounds for std
@ -42,20 +50,25 @@ class Tacotron2(nn.Module):
if num_speakers > 1:
self.speaker_embedding = nn.Embedding(num_speakers, 512)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(encoder_in_features)
self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win,
attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, forward_attn_mask,
location_attn, attn_K, separate_stopnet, proj_speaker_dim)
self.postnet = Postnet(self.postnet_output_dim)
# global style token layers
if self.gst:
gst_embedding_dim = encoder_in_features
self.gst_layer = GST(num_mel=80,
num_heads=4,
num_style_tokens=10,
embedding_dim=gst_embedding_dim)
# backward pass decoder
if self.bidirectional_decoder:
self._init_backward_decoder()
# setup DDC
if self.double_decoder_consistency:
self._init_coarse_decoder()
@staticmethod
def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):
@ -63,31 +76,60 @@ class Tacotron2(nn.Module):
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
return mel_outputs, mel_outputs_postnet, alignments
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None):
self._init_states()
# compute mask for padding
# B x T_in_max (boolean)
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
# B x D_embed x T_in_max
embedded_inputs = self.embedding(text).transpose(1, 2)
# B x T_in_max x D_en
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
# adding speaker embedding to encoder output
# TODO: multi-speaker
# B x speaker_embed_dim
if speaker_ids is not None:
self.compute_speaker_embedding(speaker_ids)
if self.num_speakers > 1:
# B x T_in x embed_dim + speaker_embed_dim
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
self.speaker_embeddings)
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
# global style token
if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
# B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
decoder_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs, input_mask)
# sequence masking
if mel_lengths is not None:
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
# B x mel_dim x T_out
postnet_outputs = self.postnet(decoder_outputs)
# sequence masking
if output_mask is not None:
postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs)
# B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
decoder_outputs, postnet_outputs, alignments)
if self.bidirectional_decoder:
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
if self.double_decoder_consistency:
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(mel_specs, encoder_outputs, alignments, input_mask)
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
return decoder_outputs, postnet_outputs, alignments, stop_tokens
@torch.no_grad()
def inference(self, text, speaker_ids=None):
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder.inference(embedded_inputs)
if speaker_ids is not None:
self.compute_speaker_embedding(speaker_ids)
if self.num_speakers > 1:
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
self.speaker_embeddings)
mel_outputs, alignments, stop_tokens = self.decoder.inference(
encoder_outputs)
mel_outputs_postnet = self.postnet(mel_outputs)
@ -112,22 +154,16 @@ class Tacotron2(nn.Module):
mel_outputs, mel_outputs_postnet, alignments)
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
def _speaker_embedding_pass(self, encoder_outputs, speaker_ids):
# TODO: multi-speaker
# if hasattr(self, "speaker_embedding") and speaker_ids is None:
# raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
# if hasattr(self, "speaker_embedding") and speaker_ids is not None:
# speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
# encoder_outputs.size(1),
# -1)
# encoder_outputs = encoder_outputs + speaker_embeddings
# return encoder_outputs
pass

models/tacotron_abstract.py (new file, +180 lines)

@ -0,0 +1,180 @@
import copy
from abc import ABC, abstractmethod
import torch
from torch import nn
from TTS.utils.generic_utils import sequence_mask
class TacotronAbstract(ABC, nn.Module):
def __init__(self,
num_chars,
num_speakers,
r,
postnet_output_dim=80,
decoder_output_dim=80,
attn_type='original',
attn_win=False,
attn_norm="softmax",
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
attn_K=5,
separate_stopnet=True,
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
gst=False):
""" Abstract Tacotron class """
super().__init__()
self.num_chars = num_chars
self.r = r
self.decoder_output_dim = decoder_output_dim
self.postnet_output_dim = postnet_output_dim
self.gst = gst
self.num_speakers = num_speakers
self.bidirectional_decoder = bidirectional_decoder
self.double_decoder_consistency = double_decoder_consistency
self.ddc_r = ddc_r
self.attn_type = attn_type
self.attn_win = attn_win
self.attn_norm = attn_norm
self.prenet_type = prenet_type
self.prenet_dropout = prenet_dropout
self.forward_attn = forward_attn
self.trans_agent = trans_agent
self.forward_attn_mask = forward_attn_mask
self.location_attn = location_attn
self.attn_K = attn_K
self.separate_stopnet = separate_stopnet
# layers
self.embedding = None
self.encoder = None
self.decoder = None
self.postnet = None
# global style token
if self.gst:
self.gst_layer = None
# model states
self.speaker_embeddings = None
self.speaker_embeddings_projected = None
# additional layers
self.decoder_backward = None
self.coarse_decoder = None
#############################
# INIT FUNCTIONS
#############################
def _init_states(self):
self.speaker_embeddings = None
self.speaker_embeddings_projected = None
def _init_backward_decoder(self):
self.decoder_backward = copy.deepcopy(self.decoder)
def _init_coarse_decoder(self):
self.coarse_decoder = copy.deepcopy(self.decoder)
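# DDC: the coarse decoder is a copy of the main decoder, driven with the larger reduction factor ddc_r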
self.coarse_decoder.r_init = self.ddc_r
self.coarse_decoder.set_r(self.ddc_r)
#############################
# CORE FUNCTIONS
#############################
@abstractmethod
def forward(self):
pass
@abstractmethod
def inference(self):
pass
#############################
# COMMON COMPUTE FUNCTIONS
#############################
def compute_masks(self, text_lengths, mel_lengths):
"""Compute masks against sequence paddings."""
# B x T_in_max (boolean)
device = text_lengths.device
input_mask = sequence_mask(text_lengths).to(device)
output_mask = None
if mel_lengths is not None:
max_len = mel_lengths.max()
r = self.decoder.r
max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
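# round the mask length up to a multiple of the decoder's reduction factor, e.g. max_len=120 with r=7 (120 % 7 == 1) pads to 120 + (7 - 1) = 126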
output_mask = sequence_mask(mel_lengths, max_len=max_len).to(device)
return input_mask, output_mask
def _backward_pass(self, mel_specs, encoder_outputs, mask):
""" Run backwards decoder """
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
self.speaker_embeddings_projected)
decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
return decoder_outputs_b, alignments_b
def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments,
input_mask):
""" Double Decoder Consistency """
T = mel_specs.shape[1]
if T % self.coarse_decoder.r > 0:
padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r)
mel_specs = torch.nn.functional.pad(mel_specs,
(0, 0, 0, padding_size, 0, 0))
decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
encoder_outputs.detach(), mel_specs, input_mask)
# scale_factor = self.decoder.r_init / self.decoder.r
alignments_backward = torch.nn.functional.interpolate(
alignments_backward.transpose(1, 2),
size=alignments.shape[1],
mode='nearest').transpose(1, 2)
decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2)
decoder_outputs_backward = decoder_outputs_backward[:, :T, :]
return decoder_outputs_backward, alignments_backward
#############################
# EMBEDDING FUNCTIONS
#############################
def compute_speaker_embedding(self, speaker_ids):
""" Compute speaker embedding vectors """
if hasattr(self, "speaker_embedding") and speaker_ids is None:
raise RuntimeError(
" [!] Model has speaker embedding layer but speaker_id is not provided"
)
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
self.speaker_embeddings = self.speaker_embedding(speaker_ids).unsqueeze(1)
if hasattr(self, "speaker_project_mel") and speaker_ids is not None:
self.speaker_embeddings_projected = self.speaker_project_mel(
self.speaker_embeddings).squeeze(1)
def compute_gst(self, inputs, mel_specs):
""" Compute global style token """
# pylint: disable=not-callable
gst_outputs = self.gst_layer(mel_specs)
inputs = self._add_speaker_embedding(inputs, gst_outputs)
return inputs
@staticmethod
def _add_speaker_embedding(outputs, speaker_embeddings):
speaker_embeddings_ = speaker_embeddings.expand(
outputs.size(0), outputs.size(1), -1)
outputs = outputs + speaker_embeddings_
return outputs
@staticmethod
def _concat_speaker_embedding(outputs, speaker_embeddings):
speaker_embeddings_ = speaker_embeddings.expand(
outputs.size(0), outputs.size(1), -1)
outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
return outputs
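For reference, the following self-contained sketch (illustrative only, not part of the package; batch size, lengths and the random tensors are made up) shows the shape bookkeeping behind ```_coarse_decoder_pass``` and the attention-consistency term that ```TacotronLoss``` adds when ```double_decoder_consistency``` is enabled:
```
import torch

B, T_in = 2, 50                  # batch size, input (character) length
T_out, ddc_r = 126, 7            # padded target length, coarse reduction factor
fine_steps = T_out               # main decoder steps (with r=1 as in the default config)
coarse_steps = T_out // ddc_r    # coarse decoder steps

alignments = torch.rand(B, fine_steps, T_in)             # main decoder attention map
alignments_backward = torch.rand(B, coarse_steps, T_in)  # coarse decoder attention map

# nearest-neighbour resize along the decoder-step axis, as in _coarse_decoder_pass
alignments_backward = torch.nn.functional.interpolate(
    alignments_backward.transpose(1, 2),    # B x T_in x coarse_steps
    size=alignments.shape[1],               # -> fine_steps
    mode='nearest').transpose(1, 2)         # back to B x fine_steps x T_in

# attention-consistency term combined into the total loss in TacotronLoss
attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backward)
print(alignments_backward.shape, attention_c_loss.item())
```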


@ -4,7 +4,7 @@
"audio":{ "audio":{
// Audio processing parameters // Audio processing parameters
"num_mels": 80, // size of the mel spec frame. "num_mels": 80, // size of the mel spec frame.
"num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame. "num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame.
"sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
"frame_length_ms": 50, // stft window length in ms. "frame_length_ms": 50, // stft window length in ms.
@ -31,19 +31,19 @@
"reinit_layers": [], "reinit_layers": [],
"model": "Tacotron2", // one of the model in models/ "model": "Tacotron2", // one of the model in models/
"grad_clip": 1, // upper limit for gradients for clipping. "grad_clip": 1, // upper limit for gradients for clipping.
"epochs": 1000, // total number of epochs to train. "epochs": 1000, // total number of epochs to train.
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
"lr_decay": false, // if true, Noam learning rate decaying is applied through training. "lr_decay": false, // if true, Noam learning rate decaying is applied through training.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"windowing": false, // Enables attention windowing. Used only in eval mode. "windowing": false, // Enables attention windowing. Used only in eval mode.
"memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5. "memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
"attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron. "attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
"prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn". "prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
"prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet. "prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet.
"use_forward_attn": true, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster. "use_forward_attn": true, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
"forward_attn_mask": false, "forward_attn_mask": false,
"attention_type": "original", "attention_type": "original",
"attention_heads": 5, "attention_heads": 5,
"bidirectional_decoder": false, "bidirectional_decoder": false,
@ -51,13 +51,15 @@
"location_attn": false, // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default. "location_attn": false, // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default.
"loss_masking": true, // enable / disable loss masking against the sequence padding. "loss_masking": true, // enable / disable loss masking against the sequence padding.
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars. "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
"stopnet": true, // Train stopnet predicting the end of synthesis. "stopnet": true, // Train stopnet predicting the end of synthesis.
"separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER. "separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
"use_gst": false, "use_gst": false,
"double_decoder_consistency": true, // use DDC explained here https://erogol.com/solving-attention-problems-of-tts-models-with-double-decoder-consistency-draft/
"ddc_r": 7, // reduction rate for coarse decoder.
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
"eval_batch_size":16, "eval_batch_size":16,
"r": 1, // Number of frames to predict for step. "r": 1, // Number of frames to predict for step.
"wd": 0.000001, // Weight decay weight. "wd": 0.000001, // Weight decay weight.
"checkpoint": true, // If true, it saves checkpoints per "save_step" "checkpoint": true, // If true, it saves checkpoints per "save_step"


@ -51,7 +51,7 @@ class TacotronTrainTest(unittest.TestCase):
optimizer = optim.Adam(model.parameters(), lr=c.lr)
for i in range(5):
mel_out, mel_postnet_out, align, stop_tokens = model.forward(
input, input_lengths, mel_spec, mel_lengths, speaker_ids)
assert torch.sigmoid(stop_tokens).data.max() <= 1.0
assert torch.sigmoid(stop_tokens).data.min() >= 0.0
optimizer.zero_grad()


@ -66,7 +66,7 @@ class TacotronTrainTest(unittest.TestCase):
optimizer = optim.Adam(model.parameters(), lr=c.lr)
for _ in range(5):
mel_out, linear_out, align, stop_tokens = model.forward(
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids)
optimizer.zero_grad()
loss = criterion(mel_out, mel_spec, mel_lengths)
stop_loss = criterion_st(stop_tokens, stop_targets)
@ -95,6 +95,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
mel_spec = torch.rand(8, 120, c.audio['num_mels']).to(device)
linear_spec = torch.rand(8, 120, c.audio['num_freq']).to(device)
mel_lengths = torch.randint(20, 120, (8, )).long().to(device)
mel_lengths[-1] = 120
stop_targets = torch.zeros(8, 120, 1).float().to(device)
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
@ -130,7 +131,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
optimizer = optim.Adam(model.parameters(), lr=c.lr)
for _ in range(10):
mel_out, linear_out, align, stop_tokens = model.forward(
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids)
optimizer.zero_grad()
loss = criterion(mel_out, mel_spec, mel_lengths)
stop_loss = criterion_st(stop_tokens, stop_targets)


@ -158,13 +158,14 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
optimizer_st.zero_grad()
# forward pass model
if c.bidirectional_decoder or c.double_decoder_consistency:
decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids)
else:
decoder_output, postnet_output, alignments, stop_tokens = model(
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids)
decoder_backward_output = None
alignments_backward = None
# set the alignment lengths wrt reduction factor for guided attention
if mel_lengths.max() % model.decoder.r != 0:
@ -176,7 +177,8 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
loss_dict = criterion(postnet_output, decoder_output, mel_input,
linear_input, stop_tokens, stop_targets,
mel_lengths, decoder_backward_output,
alignments, alignment_lengths, alignments_backward,
text_lengths)
if c.bidirectional_decoder:
keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_backward_loss'].item(),
'avg_decoder_c_loss': loss_dict['decoder_c_loss'].item()})


@ -160,13 +160,16 @@ def setup_model(num_chars, num_speakers, c):
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
double_decoder_consistency=c.double_decoder_consistency,
ddc_r=c.ddc_r)
elif c.model.lower() == "tacotron2":
model = MyModel(num_chars=num_chars,
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=c.audio['num_mels'],
decoder_output_dim=c.audio['num_mels'],
gst=c.use_gst,
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
@ -178,7 +181,9 @@ def setup_model(num_chars, num_speakers, c):
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
double_decoder_consistency=c.double_decoder_consistency,
ddc_r=c.ddc_r)
return model
class KeepAverage():
@ -313,6 +318,8 @@ def check_config(c):
_check_argument('transition_agent', c, restricted=True, val_type=bool)
_check_argument('location_attn', c, restricted=True, val_type=bool)
_check_argument('bidirectional_decoder', c, restricted=True, val_type=bool)
_check_argument('double_decoder_consistency', c, restricted=True, val_type=bool)
_check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)
# stopnet
_check_argument('stopnet', c, restricted=True, val_type=bool)


@ -77,6 +77,7 @@ class MultiScaleSTFTLoss(torch.nn.Module):
class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
""" Multiscale STFT loss for multi band model outputs """
# pylint: disable=no-self-use
def forward(self, y_hat, y):
y_hat = y_hat.view(-1, 1, y_hat.shape[2])
y = y.view(-1, 1, y.shape[2])
@ -85,6 +86,7 @@ class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
class MSEGLoss(nn.Module):
""" Mean Squared Generator Loss """
# pylint: disable=no-self-use
def forward(self, score_fake):
loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
return loss_fake
@ -92,6 +94,7 @@ class MSEGLoss(nn.Module):
class HingeGLoss(nn.Module):
""" Hinge Discriminator Loss """
# pylint: disable=no-self-use
def forward(self, score_fake):
loss_fake = torch.mean(F.relu(1. + score_fake))
return loss_fake
@ -104,6 +107,7 @@ class HingeGLoss(nn.Module):
class MSEDLoss(nn.Module):
""" Mean Squared Discriminator Loss """
# pylint: disable=no-self-use
def forward(self, score_fake, score_real):
loss_real = torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
@ -113,6 +117,7 @@ class MSEDLoss(nn.Module):
class HingeDLoss(nn.Module):
""" Hinge Discriminator Loss """
# pylint: disable=no-self-use
def forward(self, score_fake, score_real):
loss_real = torch.mean(F.relu(1. - score_real))
loss_fake = torch.mean(F.relu(1. + score_fake))
@ -121,6 +126,7 @@ class HingeDLoss(nn.Module):
class MelganFeatureLoss(nn.Module):
# pylint: disable=no-self-use
def forward(self, fake_feats, real_feats):
loss_feats = 0
for fake_feat, real_feat in zip(fake_feats, real_feats):
@ -193,8 +199,8 @@ class GeneratorLoss(nn.Module):
self.stft_loss_weight = C.stft_loss_weight
self.subband_stft_loss_weight = C.subband_stft_loss_weight
self.mse_gan_loss_weight = C.mse_G_loss_weight
self.hinge_gan_loss_weight = C.hinge_G_loss_weight
self.feat_match_loss_weight = C.feat_match_loss_weight
if C.use_stft_loss:


@ -52,7 +52,7 @@ def setup_loader(ap, is_val=False, verbose=False):
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(dataset,
batch_size=1 if is_val else c.batch_size,
shuffle=True,
drop_last=False,
sampler=None,
num_workers=c.num_val_loader_workers
@ -120,11 +120,13 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
y_hat = model_G(c_G)
y_hat_sub = None
y_G_sub = None
y_hat_vis = y_hat # for visualization
# PQMF formatting
if y_hat.shape[1] > 1:
y_hat_sub = y_hat
y_hat = model_G.pqmf_synthesis(y_hat)
y_hat_vis = y_hat
y_G_sub = model_G.pqmf_analysis(y_G)
if global_step > c.steps_to_start_discriminator:
@ -171,7 +173,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
loss_dict = dict()
for key, value in loss_G_dict.items():
if isinstance(value, int):
loss_dict[key] = value
else:
loss_dict[key] = value.item()
##############################
# DISCRIMINATOR
@ -265,12 +270,12 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
model_losses=loss_dict)
# compute spectrograms
figures = plot_results(y_hat_vis, y_G, ap, global_step,
'train')
tb_logger.tb_train_figures(global_step, figures)
# Sample audio
sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
tb_logger.tb_train_audios(global_step,
{'train/audio': sample_voice},
c.audio["sample_rate"])
@ -322,8 +327,12 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
y_hat = model_G.pqmf_synthesis(y_hat)
y_G_sub = model_G.pqmf_analysis(y_G)
if len(signature(model_D.forward).parameters) == 2:
D_out_fake = model_D(y_hat, c_G)
else:
D_out_fake = model_D(y_hat)
D_out_real = None
if c.use_feat_match_loss:
with torch.no_grad():
D_out_real = model_D(y_G)
@ -354,7 +363,7 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
for key, value in loss_G_dict.items():
update_eval_values['avg_' + key] = value.item()
update_eval_values['avg_loader_time'] = loader_time
update_eval_values['avg_step_time'] = step_time
keep_avg.update_values(update_eval_values)
# print eval stats