mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'dev' of https://github.com/mozilla/TTS into dev
This commit is contained in:
commit
d00b91710a
21
README.md
21
README.md
|
@ -11,15 +11,18 @@ If you are new, you can also find [here](http://www.erogol.com/text-speech-deep-
|
|||
|
||||
[](https://sourcerer.io/fame/erogol/erogol/TTS/links/0)[](https://sourcerer.io/fame/erogol/erogol/TTS/links/1)[](https://sourcerer.io/fame/erogol/erogol/TTS/links/2)[](https://sourcerer.io/fame/erogol/erogol/TTS/links/3)[](https://sourcerer.io/fame/erogol/erogol/TTS/links/4)[](https://sourcerer.io/fame/erogol/erogol/TTS/links/5)[](https://sourcerer.io/fame/erogol/erogol/TTS/links/6)[](https://sourcerer.io/fame/erogol/erogol/TTS/links/7)
|
||||
|
||||
## TTS Performance
|
||||
## TTS Performance
|
||||
<p align="center"><img src="https://camo.githubusercontent.com/9fa79f977015e55eb9ec7aa32045555f60d093d3/68747470733a2f2f646973636f757273652d706161732d70726f64756374696f6e2d636f6e74656e742e73332e6475616c737461636b2e75732d656173742d312e616d617a6f6e6177732e636f6d2f6f7074696d697a65642f33582f362f342f363432386639383065396563373531633234386535393134363038393566373838316165633063365f325f363930783339342e706e67"/></p>
|
||||
|
||||
[Details...](https://github.com/mozilla/TTS/wiki/Mean-Opinion-Score-Results)
|
||||
|
||||
## Features
|
||||
- High performance Text2Speech models on Torch and Tensorflow 2.0.
|
||||
- High performance Speaker Encoder to compute speaker embeddings efficiently.
|
||||
- Integration with various Neural Vocoders (PWGAN, MelGAN, WaveRNN)
|
||||
- High performance Deep Learning models for Text2Speech related tasks.
|
||||
- Text2Speech models (Tacotron, Tacotron2).
|
||||
- Speaker Encoder to compute speaker embeddings efficiently.
|
||||
- Vocoder models (MelGAN, Multiband-MelGAN, GAN-TTS)
|
||||
- Support for multi-speaker TTS training.
|
||||
- Ability to convert Torch models to Tensorflow 2.0 for inference.
|
||||
- Released trained models.
|
||||
- Efficient training codes for PyTorch. (soon for Tensorflow 2.0)
|
||||
- Codes to convert Torch models to Tensorflow 2.0.
|
||||
|
@ -84,7 +87,7 @@ Audio length is approximately 6 secs.
|
|||
|
||||
|
||||
## Datasets and Data-Loading
|
||||
TTS provides a generic dataloder easy to use for new datasets. You need to write an preprocessor function to integrate your own dataset.Check ```datasets/preprocess.py``` to see some examples. After the function, you need to set ```dataset``` field in ```config.json```. Do not forget other data related fields too.
|
||||
TTS provides a generic dataloder easy to use for new datasets. You need to write an preprocessor function to integrate your own dataset.Check ```datasets/preprocess.py``` to see some examples. After the function, you need to set ```dataset``` field in ```config.json```. Do not forget other data related fields too.
|
||||
|
||||
Some of the open-sourced datasets that we successfully applied TTS, are linked below.
|
||||
|
||||
|
@ -96,9 +99,9 @@ Some of the open-sourced datasets that we successfully applied TTS, are linked b
|
|||
- [Spanish](https://drive.google.com/file/d/1Sm_zyBo67XHkiFhcRSQ4YaHPYM0slO_e/view?usp=sharing) - thx! @carlfm01
|
||||
|
||||
## Training and Fine-tuning LJ-Speech
|
||||
Here you can find a [CoLab](https://gist.github.com/erogol/97516ad65b44dbddb8cd694953187c5b) notebook for a hands-on example, training LJSpeech. Or you can manually follow the guideline below.
|
||||
Here you can find a [CoLab](https://gist.github.com/erogol/97516ad65b44dbddb8cd694953187c5b) notebook for a hands-on example, training LJSpeech. Or you can manually follow the guideline below.
|
||||
|
||||
To start with, split ```metadata.csv``` into train and validation subsets respectively ```metadata_train.csv``` and ```metadata_val.csv```. Note that for text-to-speech, validation performance might be misleading since the loss value does not directly measure the voice quality to the human ear and it also does not measure the attention module performance. Therefore, running the model with new sentences and listening to the results is the best way to go.
|
||||
To start with, split ```metadata.csv``` into train and validation subsets respectively ```metadata_train.csv``` and ```metadata_val.csv```. Note that for text-to-speech, validation performance might be misleading since the loss value does not directly measure the voice quality to the human ear and it also does not measure the attention module performance. Therefore, running the model with new sentences and listening to the results is the best way to go.
|
||||
|
||||
```
|
||||
shuf metadata.csv > metadata_shuf.csv
|
||||
|
@ -137,10 +140,10 @@ cardboardlinter --refspec master
|
|||
```
|
||||
|
||||
## Collaborative Experimentation Guide
|
||||
If you like to use TTS to try a new idea and like to share your experiments with the community, we urge you to use the following guideline for a better collaboration.
|
||||
If you like to use TTS to try a new idea and like to share your experiments with the community, we urge you to use the following guideline for a better collaboration.
|
||||
(If you have an idea for better collaboration, let us know)
|
||||
- Create a new branch.
|
||||
- Open an issue pointing your branch.
|
||||
- Open an issue pointing your branch.
|
||||
- Explain your experiment.
|
||||
- Share your results as you proceed. (Tensorboard log files, audio results, visuals etc.)
|
||||
- Use LJSpeech dataset (for English) if you like to compare results with the released models. (It is the most open scalable dataset for quick experimentation)
|
||||
|
|
|
@ -96,6 +96,8 @@
|
|||
"transition_agent": false, // enable/disable transition agent of forward attention.
|
||||
"location_attn": true, // enable_disable location sensitive attention. It is enabled for TACOTRON by default.
|
||||
"bidirectional_decoder": false, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
|
||||
"double_decoder_consistency": true, // use DDC explained here https://erogol.com/solving-attention-problems-of-tts-models-with-double-decoder-consistency-draft/
|
||||
"ddc_r": 7, // reduction rate for coarse decoder.
|
||||
|
||||
// STOPNET
|
||||
"stopnet": true, // Train stopnet predicting the end of synthesis.
|
||||
|
|
|
@ -184,7 +184,7 @@ class TacotronLoss(torch.nn.Module):
|
|||
|
||||
def forward(self, postnet_output, decoder_output, mel_input, linear_input,
|
||||
stopnet_output, stopnet_target, output_lens, decoder_b_output,
|
||||
alignments, alignment_lens, input_lens):
|
||||
alignments, alignment_lens, alignments_backwards, input_lens):
|
||||
|
||||
return_dict = {}
|
||||
# decoder and postnet losses
|
||||
|
@ -226,6 +226,15 @@ class TacotronLoss(torch.nn.Module):
|
|||
return_dict['decoder_b_loss'] = decoder_b_loss
|
||||
return_dict['decoder_c_loss'] = decoder_c_loss
|
||||
|
||||
# double decoder consistency loss (if enabled)
|
||||
if self.config.double_decoder_consistency:
|
||||
decoder_b_loss = self.criterion(decoder_b_output, mel_input, output_lens)
|
||||
# decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output)
|
||||
attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards)
|
||||
loss += decoder_b_loss + attention_c_loss
|
||||
return_dict['decoder_coarse_loss'] = decoder_b_loss
|
||||
return_dict['decoder_ddc_loss'] = attention_c_loss
|
||||
|
||||
# guided attention loss (if enabled)
|
||||
if self.config.ga_alpha > 0:
|
||||
ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens)
|
||||
|
|
|
@ -1,23 +1,21 @@
|
|||
# coding: utf-8
|
||||
import torch
|
||||
import copy
|
||||
from torch import nn
|
||||
from TTS.layers.tacotron import Encoder, Decoder, PostCBHG
|
||||
from TTS.utils.generic_utils import sequence_mask
|
||||
|
||||
from TTS.layers.gst_layers import GST
|
||||
from TTS.layers.tacotron import Decoder, Encoder, PostCBHG
|
||||
from TTS.models.tacotron_abstract import TacotronAbstract
|
||||
|
||||
|
||||
class Tacotron(nn.Module):
|
||||
class Tacotron(TacotronAbstract):
|
||||
def __init__(self,
|
||||
num_chars,
|
||||
num_speakers,
|
||||
r=5,
|
||||
postnet_output_dim=1025,
|
||||
decoder_output_dim=80,
|
||||
memory_size=5,
|
||||
attn_type='original',
|
||||
attn_win=False,
|
||||
gst=False,
|
||||
attn_norm="sigmoid",
|
||||
prenet_type="original",
|
||||
prenet_dropout=True,
|
||||
|
@ -27,38 +25,41 @@ class Tacotron(nn.Module):
|
|||
location_attn=True,
|
||||
attn_K=5,
|
||||
separate_stopnet=True,
|
||||
bidirectional_decoder=False):
|
||||
super(Tacotron, self).__init__()
|
||||
self.r = r
|
||||
self.decoder_output_dim = decoder_output_dim
|
||||
self.postnet_output_dim = postnet_output_dim
|
||||
self.gst = gst
|
||||
self.num_speakers = num_speakers
|
||||
self.bidirectional_decoder = bidirectional_decoder
|
||||
decoder_dim = 512 if num_speakers > 1 else 256
|
||||
encoder_dim = 512 if num_speakers > 1 else 256
|
||||
bidirectional_decoder=False,
|
||||
double_decoder_consistency=False,
|
||||
ddc_r=None,
|
||||
gst=False,
|
||||
memory_size=5):
|
||||
super(Tacotron,
|
||||
self).__init__(num_chars, num_speakers, r, postnet_output_dim,
|
||||
decoder_output_dim, attn_type, attn_win,
|
||||
attn_norm, prenet_type, prenet_dropout,
|
||||
forward_attn, trans_agent, forward_attn_mask,
|
||||
location_attn, attn_K, separate_stopnet,
|
||||
bidirectional_decoder, double_decoder_consistency,
|
||||
ddc_r, gst)
|
||||
decoder_in_features = 512 if num_speakers > 1 else 256
|
||||
encoder_in_features = 512 if num_speakers > 1 else 256
|
||||
speaker_embedding_dim = 256
|
||||
proj_speaker_dim = 80 if num_speakers > 1 else 0
|
||||
# embedding layer
|
||||
# base model layers
|
||||
self.embedding = nn.Embedding(num_chars, 256, padding_idx=0)
|
||||
self.embedding.weight.data.normal_(0, 0.3)
|
||||
# boilerplate model
|
||||
self.encoder = Encoder(encoder_dim)
|
||||
self.decoder = Decoder(decoder_dim, decoder_output_dim, r, memory_size, attn_type, attn_win,
|
||||
self.encoder = Encoder(encoder_in_features)
|
||||
self.decoder = Decoder(decoder_in_features, decoder_output_dim, r, memory_size, attn_type, attn_win,
|
||||
attn_norm, prenet_type, prenet_dropout,
|
||||
forward_attn, trans_agent, forward_attn_mask,
|
||||
location_attn, attn_K, separate_stopnet,
|
||||
proj_speaker_dim)
|
||||
if self.bidirectional_decoder:
|
||||
self.decoder_backward = copy.deepcopy(self.decoder)
|
||||
self.postnet = PostCBHG(decoder_output_dim)
|
||||
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2,
|
||||
postnet_output_dim)
|
||||
# speaker embedding layers
|
||||
if num_speakers > 1:
|
||||
self.speaker_embedding = nn.Embedding(num_speakers, 256)
|
||||
self.speaker_embedding = nn.Embedding(num_speakers, speaker_embedding_dim)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
self.speaker_project_mel = nn.Sequential(
|
||||
nn.Linear(256, proj_speaker_dim), nn.Tanh())
|
||||
nn.Linear(speaker_embedding_dim, proj_speaker_dim), nn.Tanh())
|
||||
self.speaker_embeddings = None
|
||||
self.speaker_embeddings_projected = None
|
||||
# global style token layers
|
||||
|
@ -68,28 +69,15 @@ class Tacotron(nn.Module):
|
|||
num_heads=4,
|
||||
num_style_tokens=10,
|
||||
embedding_dim=gst_embedding_dim)
|
||||
# backward pass decoder
|
||||
if self.bidirectional_decoder:
|
||||
self._init_backward_decoder()
|
||||
# setup DDC
|
||||
if self.double_decoder_consistency:
|
||||
self._init_coarse_decoder()
|
||||
|
||||
def _init_states(self):
|
||||
self.speaker_embeddings = None
|
||||
self.speaker_embeddings_projected = None
|
||||
|
||||
def compute_speaker_embedding(self, speaker_ids):
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is None:
|
||||
raise RuntimeError(
|
||||
" [!] Model has speaker embedding layer but speaker_id is not provided"
|
||||
)
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
|
||||
self.speaker_embeddings = self._compute_speaker_embedding(
|
||||
speaker_ids)
|
||||
self.speaker_embeddings_projected = self.speaker_project_mel(
|
||||
self.speaker_embeddings).squeeze(1)
|
||||
|
||||
def compute_gst(self, inputs, mel_specs):
|
||||
gst_outputs = self.gst_layer(mel_specs)
|
||||
inputs = self._add_speaker_embedding(inputs, gst_outputs)
|
||||
return inputs
|
||||
|
||||
def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
|
||||
def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None):
|
||||
"""
|
||||
Shapes:
|
||||
- characters: B x T_in
|
||||
|
@ -98,45 +86,59 @@ class Tacotron(nn.Module):
|
|||
- speaker_ids: B x 1
|
||||
"""
|
||||
self._init_states()
|
||||
mask = sequence_mask(text_lengths).to(characters.device)
|
||||
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
||||
# B x T_in x embed_dim
|
||||
inputs = self.embedding(characters)
|
||||
# B x speaker_embed_dim
|
||||
self.compute_speaker_embedding(speaker_ids)
|
||||
if speaker_ids is not None:
|
||||
self.compute_speaker_embedding(speaker_ids)
|
||||
if self.num_speakers > 1:
|
||||
# B x T_in x embed_dim + speaker_embed_dim
|
||||
inputs = self._concat_speaker_embedding(inputs,
|
||||
self.speaker_embeddings)
|
||||
# B x T_in x encoder_dim
|
||||
# B x T_in x encoder_in_features
|
||||
encoder_outputs = self.encoder(inputs)
|
||||
# sequence masking
|
||||
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
|
||||
# global style token
|
||||
if self.gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
|
||||
if self.num_speakers > 1:
|
||||
encoder_outputs = self._concat_speaker_embedding(
|
||||
encoder_outputs, self.speaker_embeddings)
|
||||
# decoder_outputs: B x decoder_dim x T_out
|
||||
# alignments: B x T_in x encoder_dim
|
||||
# decoder_outputs: B x decoder_in_features x T_out
|
||||
# alignments: B x T_in x encoder_in_features
|
||||
# stop_tokens: B x T_in
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder(
|
||||
encoder_outputs, mel_specs, mask,
|
||||
encoder_outputs, mel_specs, input_mask,
|
||||
self.speaker_embeddings_projected)
|
||||
# B x T_out x decoder_dim
|
||||
# sequence masking
|
||||
if output_mask is not None:
|
||||
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
|
||||
# B x T_out x decoder_in_features
|
||||
postnet_outputs = self.postnet(decoder_outputs)
|
||||
# sequence masking
|
||||
if output_mask is not None:
|
||||
postnet_outputs = postnet_outputs * output_mask.unsqueeze(2).expand_as(postnet_outputs)
|
||||
# B x T_out x posnet_dim
|
||||
postnet_outputs = self.last_linear(postnet_outputs)
|
||||
# B x T_out x decoder_dim
|
||||
# B x T_out x decoder_in_features
|
||||
decoder_outputs = decoder_outputs.transpose(1, 2).contiguous()
|
||||
if self.bidirectional_decoder:
|
||||
decoder_outputs_backward, alignments_backward = self._backward_inference(mel_specs, encoder_outputs, mask)
|
||||
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
|
||||
if self.double_decoder_consistency:
|
||||
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(mel_specs, encoder_outputs, alignments, input_mask)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, characters, speaker_ids=None, style_mel=None):
|
||||
inputs = self.embedding(characters)
|
||||
self._init_states()
|
||||
self.compute_speaker_embedding(speaker_ids)
|
||||
if speaker_ids is not None:
|
||||
self.compute_speaker_embedding(speaker_ids)
|
||||
if self.num_speakers > 1:
|
||||
inputs = self._concat_speaker_embedding(inputs,
|
||||
self.speaker_embeddings)
|
||||
|
@ -152,28 +154,3 @@ class Tacotron(nn.Module):
|
|||
postnet_outputs = self.last_linear(postnet_outputs)
|
||||
decoder_outputs = decoder_outputs.transpose(1, 2)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||
|
||||
def _backward_inference(self, mel_specs, encoder_outputs, mask):
|
||||
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
|
||||
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
|
||||
self.speaker_embeddings_projected)
|
||||
decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
|
||||
return decoder_outputs_b, alignments_b
|
||||
|
||||
def _compute_speaker_embedding(self, speaker_ids):
|
||||
speaker_embeddings = self.speaker_embedding(speaker_ids)
|
||||
return speaker_embeddings.unsqueeze_(1)
|
||||
|
||||
@staticmethod
|
||||
def _add_speaker_embedding(outputs, speaker_embeddings):
|
||||
speaker_embeddings_ = speaker_embeddings.expand(
|
||||
outputs.size(0), outputs.size(1), -1)
|
||||
outputs = outputs + speaker_embeddings_
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _concat_speaker_embedding(outputs, speaker_embeddings):
|
||||
speaker_embeddings_ = speaker_embeddings.expand(
|
||||
outputs.size(0), outputs.size(1), -1)
|
||||
outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
|
||||
return outputs
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
import copy
|
||||
import torch
|
||||
from math import sqrt
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from TTS.layers.tacotron2 import Encoder, Decoder, Postnet
|
||||
from TTS.utils.generic_utils import sequence_mask
|
||||
|
||||
from TTS.layers.gst_layers import GST
|
||||
from TTS.layers.tacotron2 import Decoder, Encoder, Postnet
|
||||
from TTS.models.tacotron_abstract import TacotronAbstract
|
||||
|
||||
|
||||
# TODO: match function arguments with tacotron
|
||||
class Tacotron2(nn.Module):
|
||||
class Tacotron2(TacotronAbstract):
|
||||
def __init__(self,
|
||||
num_chars,
|
||||
num_speakers,
|
||||
|
@ -25,16 +27,22 @@ class Tacotron2(nn.Module):
|
|||
location_attn=True,
|
||||
attn_K=5,
|
||||
separate_stopnet=True,
|
||||
bidirectional_decoder=False):
|
||||
super(Tacotron2, self).__init__()
|
||||
self.postnet_output_dim = postnet_output_dim
|
||||
self.decoder_output_dim = decoder_output_dim
|
||||
self.r = r
|
||||
self.bidirectional_decoder = bidirectional_decoder
|
||||
decoder_dim = 512 if num_speakers > 1 else 512
|
||||
encoder_dim = 512 if num_speakers > 1 else 512
|
||||
bidirectional_decoder=False,
|
||||
double_decoder_consistency=False,
|
||||
ddc_r=None,
|
||||
gst=False):
|
||||
super(Tacotron2,
|
||||
self).__init__(num_chars, num_speakers, r, postnet_output_dim,
|
||||
decoder_output_dim, attn_type, attn_win,
|
||||
attn_norm, prenet_type, prenet_dropout,
|
||||
forward_attn, trans_agent, forward_attn_mask,
|
||||
location_attn, attn_K, separate_stopnet,
|
||||
bidirectional_decoder, double_decoder_consistency,
|
||||
ddc_r, gst)
|
||||
decoder_in_features = 512 if num_speakers > 1 else 512
|
||||
encoder_in_features = 512 if num_speakers > 1 else 512
|
||||
proj_speaker_dim = 80 if num_speakers > 1 else 0
|
||||
# embedding layer
|
||||
# base layers
|
||||
self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
|
||||
std = sqrt(2.0 / (num_chars + 512))
|
||||
val = sqrt(3.0) * std # uniform bounds for std
|
||||
|
@ -42,20 +50,25 @@ class Tacotron2(nn.Module):
|
|||
if num_speakers > 1:
|
||||
self.speaker_embedding = nn.Embedding(num_speakers, 512)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
self.speaker_embeddings = None
|
||||
self.speaker_embeddings_projected = None
|
||||
self.encoder = Encoder(encoder_dim)
|
||||
self.decoder = Decoder(decoder_dim, self.decoder_output_dim, r, attn_type, attn_win,
|
||||
self.encoder = Encoder(encoder_in_features)
|
||||
self.decoder = Decoder(decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win,
|
||||
attn_norm, prenet_type, prenet_dropout,
|
||||
forward_attn, trans_agent, forward_attn_mask,
|
||||
location_attn, attn_K, separate_stopnet, proj_speaker_dim)
|
||||
if self.bidirectional_decoder:
|
||||
self.decoder_backward = copy.deepcopy(self.decoder)
|
||||
self.postnet = Postnet(self.postnet_output_dim)
|
||||
|
||||
def _init_states(self):
|
||||
self.speaker_embeddings = None
|
||||
self.speaker_embeddings_projected = None
|
||||
# global style token layers
|
||||
if self.gst:
|
||||
gst_embedding_dim = encoder_in_features
|
||||
self.gst_layer = GST(num_mel=80,
|
||||
num_heads=4,
|
||||
num_style_tokens=10,
|
||||
embedding_dim=gst_embedding_dim)
|
||||
# backward pass decoder
|
||||
if self.bidirectional_decoder:
|
||||
self._init_backward_decoder()
|
||||
# setup DDC
|
||||
if self.double_decoder_consistency:
|
||||
self._init_coarse_decoder()
|
||||
|
||||
@staticmethod
|
||||
def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):
|
||||
|
@ -63,31 +76,60 @@ class Tacotron2(nn.Module):
|
|||
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
|
||||
return mel_outputs, mel_outputs_postnet, alignments
|
||||
|
||||
def forward(self, text, text_lengths, mel_specs=None, speaker_ids=None):
|
||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None):
|
||||
self._init_states()
|
||||
# compute mask for padding
|
||||
mask = sequence_mask(text_lengths).to(text.device)
|
||||
# B x T_in_max (boolean)
|
||||
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
||||
# B x D_embed x T_in_max
|
||||
embedded_inputs = self.embedding(text).transpose(1, 2)
|
||||
# B x T_in_max x D_en
|
||||
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
|
||||
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
|
||||
speaker_ids)
|
||||
# adding speaker embeddding to encoder output
|
||||
# TODO: multi-speaker
|
||||
# B x speaker_embed_dim
|
||||
if speaker_ids is not None:
|
||||
self.compute_speaker_embedding(speaker_ids)
|
||||
if self.num_speakers > 1:
|
||||
# B x T_in x embed_dim + speaker_embed_dim
|
||||
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
|
||||
self.speaker_embeddings)
|
||||
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
|
||||
# global style token
|
||||
if self.gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
|
||||
# B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder(
|
||||
encoder_outputs, mel_specs, mask)
|
||||
encoder_outputs, mel_specs, input_mask)
|
||||
# sequence masking
|
||||
if mel_lengths is not None:
|
||||
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
|
||||
# B x mel_dim x T_out
|
||||
postnet_outputs = self.postnet(decoder_outputs)
|
||||
postnet_outputs = decoder_outputs + postnet_outputs
|
||||
# sequence masking
|
||||
if output_mask is not None:
|
||||
postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs)
|
||||
# B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in
|
||||
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
|
||||
decoder_outputs, postnet_outputs, alignments)
|
||||
if self.bidirectional_decoder:
|
||||
decoder_outputs_backward, alignments_backward = self._backward_inference(mel_specs, encoder_outputs, mask)
|
||||
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
|
||||
if self.double_decoder_consistency:
|
||||
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(mel_specs, encoder_outputs, alignments, input_mask)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, text, speaker_ids=None):
|
||||
embedded_inputs = self.embedding(text).transpose(1, 2)
|
||||
encoder_outputs = self.encoder.inference(embedded_inputs)
|
||||
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
|
||||
speaker_ids)
|
||||
if speaker_ids is not None:
|
||||
self.compute_speaker_embedding(speaker_ids)
|
||||
if self.num_speakers > 1:
|
||||
encoder_outputs = self._add_speaker_embedding(encoder_outputs,
|
||||
self.speaker_embeddings)
|
||||
mel_outputs, alignments, stop_tokens = self.decoder.inference(
|
||||
encoder_outputs)
|
||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||
|
@ -112,22 +154,16 @@ class Tacotron2(nn.Module):
|
|||
mel_outputs, mel_outputs_postnet, alignments)
|
||||
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
|
||||
|
||||
def _backward_inference(self, mel_specs, encoder_outputs, mask):
|
||||
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
|
||||
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
|
||||
self.speaker_embeddings_projected)
|
||||
decoder_outputs_b = decoder_outputs_b.transpose(1, 2)
|
||||
return decoder_outputs_b, alignments_b
|
||||
|
||||
def _add_speaker_embedding(self, encoder_outputs, speaker_ids):
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is None:
|
||||
raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
|
||||
speaker_embeddings = self.speaker_embedding(speaker_ids)
|
||||
def _speaker_embedding_pass(self, encoder_outputs, speaker_ids):
|
||||
# TODO: multi-speaker
|
||||
# if hasattr(self, "speaker_embedding") and speaker_ids is None:
|
||||
# raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
|
||||
# if hasattr(self, "speaker_embedding") and speaker_ids is not None:
|
||||
|
||||
speaker_embeddings.unsqueeze_(1)
|
||||
speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
|
||||
encoder_outputs.size(1),
|
||||
-1)
|
||||
encoder_outputs = encoder_outputs + speaker_embeddings
|
||||
return encoder_outputs
|
||||
# speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0),
|
||||
# encoder_outputs.size(1),
|
||||
# -1)
|
||||
# encoder_outputs = encoder_outputs + speaker_embeddings
|
||||
# return encoder_outputs
|
||||
pass
|
||||
|
|
|
@ -0,0 +1,180 @@
|
|||
import copy
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from TTS.utils.generic_utils import sequence_mask
|
||||
|
||||
|
||||
class TacotronAbstract(ABC, nn.Module):
|
||||
def __init__(self,
|
||||
num_chars,
|
||||
num_speakers,
|
||||
r,
|
||||
postnet_output_dim=80,
|
||||
decoder_output_dim=80,
|
||||
attn_type='original',
|
||||
attn_win=False,
|
||||
attn_norm="softmax",
|
||||
prenet_type="original",
|
||||
prenet_dropout=True,
|
||||
forward_attn=False,
|
||||
trans_agent=False,
|
||||
forward_attn_mask=False,
|
||||
location_attn=True,
|
||||
attn_K=5,
|
||||
separate_stopnet=True,
|
||||
bidirectional_decoder=False,
|
||||
double_decoder_consistency=False,
|
||||
ddc_r=None,
|
||||
gst=False):
|
||||
""" Abstract Tacotron class """
|
||||
super().__init__()
|
||||
self.num_chars = num_chars
|
||||
self.r = r
|
||||
self.decoder_output_dim = decoder_output_dim
|
||||
self.postnet_output_dim = postnet_output_dim
|
||||
self.gst = gst
|
||||
self.num_speakers = num_speakers
|
||||
self.bidirectional_decoder = bidirectional_decoder
|
||||
self.double_decoder_consistency = double_decoder_consistency
|
||||
self.ddc_r = ddc_r
|
||||
self.attn_type = attn_type
|
||||
self.attn_win = attn_win
|
||||
self.attn_norm = attn_norm
|
||||
self.prenet_type = prenet_type
|
||||
self.prenet_dropout = prenet_dropout
|
||||
self.forward_attn = forward_attn
|
||||
self.trans_agent = trans_agent
|
||||
self.forward_attn_mask = forward_attn_mask
|
||||
self.location_attn = location_attn
|
||||
self.attn_K = attn_K
|
||||
self.separate_stopnet = separate_stopnet
|
||||
|
||||
# layers
|
||||
self.embedding = None
|
||||
self.encoder = None
|
||||
self.decoder = None
|
||||
self.postnet = None
|
||||
|
||||
# global style token
|
||||
if self.gst:
|
||||
self.gst_layer = None
|
||||
|
||||
# model states
|
||||
self.speaker_embeddings = None
|
||||
self.speaker_embeddings_projected = None
|
||||
|
||||
# additional layers
|
||||
self.decoder_backward = None
|
||||
self.coarse_decoder = None
|
||||
|
||||
#############################
|
||||
# INIT FUNCTIONS
|
||||
#############################
|
||||
|
||||
def _init_states(self):
|
||||
self.speaker_embeddings = None
|
||||
self.speaker_embeddings_projected = None
|
||||
|
||||
def _init_backward_decoder(self):
|
||||
self.decoder_backward = copy.deepcopy(self.decoder)
|
||||
|
||||
def _init_coarse_decoder(self):
|
||||
self.coarse_decoder = copy.deepcopy(self.decoder)
|
||||
self.coarse_decoder.r_init = self.ddc_r
|
||||
self.coarse_decoder.set_r(self.ddc_r)
|
||||
|
||||
#############################
|
||||
# CORE FUNCTIONS
|
||||
#############################
|
||||
|
||||
@abstractmethod
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def inference(self):
|
||||
pass
|
||||
|
||||
#############################
|
||||
# COMMON COMPUTE FUNCTIONS
|
||||
#############################
|
||||
|
||||
def compute_masks(self, text_lengths, mel_lengths):
|
||||
"""Compute masks against sequence paddings."""
|
||||
# B x T_in_max (boolean)
|
||||
device = text_lengths.device
|
||||
input_mask = sequence_mask(text_lengths).to(device)
|
||||
output_mask = None
|
||||
if mel_lengths is not None:
|
||||
max_len = mel_lengths.max()
|
||||
r = self.decoder.r
|
||||
max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
|
||||
output_mask = sequence_mask(mel_lengths, max_len=max_len).to(device)
|
||||
return input_mask, output_mask
|
||||
|
||||
def _backward_pass(self, mel_specs, encoder_outputs, mask):
|
||||
""" Run backwards decoder """
|
||||
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
|
||||
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask,
|
||||
self.speaker_embeddings_projected)
|
||||
decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
|
||||
return decoder_outputs_b, alignments_b
|
||||
|
||||
def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments,
|
||||
input_mask):
|
||||
""" Double Decoder Consistency """
|
||||
T = mel_specs.shape[1]
|
||||
if T % self.coarse_decoder.r > 0:
|
||||
padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r)
|
||||
mel_specs = torch.nn.functional.pad(mel_specs,
|
||||
(0, 0, 0, padding_size, 0, 0))
|
||||
decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
|
||||
encoder_outputs.detach(), mel_specs, input_mask)
|
||||
# scale_factor = self.decoder.r_init / self.decoder.r
|
||||
alignments_backward = torch.nn.functional.interpolate(
|
||||
alignments_backward.transpose(1, 2),
|
||||
size=alignments.shape[1],
|
||||
mode='nearest').transpose(1, 2)
|
||||
decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2)
|
||||
decoder_outputs_backward = decoder_outputs_backward[:, :T, :]
|
||||
return decoder_outputs_backward, alignments_backward
|
||||
|
||||
#############################
|
||||
# EMBEDDING FUNCTIONS
|
||||
#############################
|
||||
|
||||
def compute_speaker_embedding(self, speaker_ids):
|
||||
""" Compute speaker embedding vectors """
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is None:
|
||||
raise RuntimeError(
|
||||
" [!] Model has speaker embedding layer but speaker_id is not provided"
|
||||
)
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
|
||||
self.speaker_embeddings = self.speaker_embedding(speaker_ids).unsqueeze(1)
|
||||
if hasattr(self, "speaker_project_mel") and speaker_ids is not None:
|
||||
self.speaker_embeddings_projected = self.speaker_project_mel(
|
||||
self.speaker_embeddings).squeeze(1)
|
||||
|
||||
def compute_gst(self, inputs, mel_specs):
|
||||
""" Compute global style token """
|
||||
# pylint: disable=not-callable
|
||||
gst_outputs = self.gst_layer(mel_specs)
|
||||
inputs = self._add_speaker_embedding(inputs, gst_outputs)
|
||||
return inputs
|
||||
|
||||
@staticmethod
|
||||
def _add_speaker_embedding(outputs, speaker_embeddings):
|
||||
speaker_embeddings_ = speaker_embeddings.expand(
|
||||
outputs.size(0), outputs.size(1), -1)
|
||||
outputs = outputs + speaker_embeddings_
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _concat_speaker_embedding(outputs, speaker_embeddings):
|
||||
speaker_embeddings_ = speaker_embeddings.expand(
|
||||
outputs.size(0), outputs.size(1), -1)
|
||||
outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
|
||||
return outputs
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
"audio":{
|
||||
// Audio processing parameters
|
||||
"num_mels": 80, // size of the mel spec frame.
|
||||
"num_mels": 80, // size of the mel spec frame.
|
||||
"num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame.
|
||||
"sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
|
||||
"frame_length_ms": 50, // stft window length in ms.
|
||||
|
@ -31,19 +31,19 @@
|
|||
|
||||
"reinit_layers": [],
|
||||
|
||||
"model": "Tacotron2", // one of the model in models/
|
||||
"model": "Tacotron2", // one of the model in models/
|
||||
"grad_clip": 1, // upper limit for gradients for clipping.
|
||||
"epochs": 1000, // total number of epochs to train.
|
||||
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
|
||||
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
|
||||
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
|
||||
"windowing": false, // Enables attention windowing. Used only in eval mode.
|
||||
"memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
|
||||
"memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
|
||||
"attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
|
||||
"prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
|
||||
"prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet.
|
||||
"prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet.
|
||||
"use_forward_attn": true, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
|
||||
"forward_attn_mask": false,
|
||||
"forward_attn_mask": false,
|
||||
"attention_type": "original",
|
||||
"attention_heads": 5,
|
||||
"bidirectional_decoder": false,
|
||||
|
@ -51,13 +51,15 @@
|
|||
"location_attn": false, // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default.
|
||||
"loss_masking": true, // enable / disable loss masking against the sequence padding.
|
||||
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
|
||||
"stopnet": true, // Train stopnet predicting the end of synthesis.
|
||||
"stopnet": true, // Train stopnet predicting the end of synthesis.
|
||||
"separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
|
||||
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
|
||||
"use_gst": false,
|
||||
|
||||
"use_gst": false,
|
||||
"double_decoder_consistency": true, // use DDC explained here https://erogol.com/solving-attention-problems-of-tts-models-with-double-decoder-consistency-draft/
|
||||
"ddc_r": 7, // reduction rate for coarse decoder.
|
||||
|
||||
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
|
||||
"eval_batch_size":16,
|
||||
"eval_batch_size":16,
|
||||
"r": 1, // Number of frames to predict for step.
|
||||
"wd": 0.000001, // Weight decay weight.
|
||||
"checkpoint": true, // If true, it saves checkpoints per "save_step"
|
||||
|
|
|
@ -51,7 +51,7 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for i in range(5):
|
||||
mel_out, mel_postnet_out, align, stop_tokens = model.forward(
|
||||
input, input_lengths, mel_spec, speaker_ids)
|
||||
input, input_lengths, mel_spec, mel_lengths, speaker_ids)
|
||||
assert torch.sigmoid(stop_tokens).data.max() <= 1.0
|
||||
assert torch.sigmoid(stop_tokens).data.min() >= 0.0
|
||||
optimizer.zero_grad()
|
||||
|
|
|
@ -66,7 +66,7 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for _ in range(5):
|
||||
mel_out, linear_out, align, stop_tokens = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, speaker_ids)
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids)
|
||||
optimizer.zero_grad()
|
||||
loss = criterion(mel_out, mel_spec, mel_lengths)
|
||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
|
@ -95,6 +95,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
|
|||
mel_spec = torch.rand(8, 120, c.audio['num_mels']).to(device)
|
||||
linear_spec = torch.rand(8, 120, c.audio['num_freq']).to(device)
|
||||
mel_lengths = torch.randint(20, 120, (8, )).long().to(device)
|
||||
mel_lengths[-1] = 120
|
||||
stop_targets = torch.zeros(8, 120, 1).float().to(device)
|
||||
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
|
||||
|
||||
|
@ -130,7 +131,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for _ in range(10):
|
||||
mel_out, linear_out, align, stop_tokens = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, speaker_ids)
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids)
|
||||
optimizer.zero_grad()
|
||||
loss = criterion(mel_out, mel_spec, mel_lengths)
|
||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
|
|
10
train.py
10
train.py
|
@ -158,13 +158,14 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
|
|||
optimizer_st.zero_grad()
|
||||
|
||||
# forward pass model
|
||||
if c.bidirectional_decoder:
|
||||
if c.bidirectional_decoder or c.double_decoder_consistency:
|
||||
decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
|
||||
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
|
||||
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids)
|
||||
else:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model(
|
||||
text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
|
||||
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids)
|
||||
decoder_backward_output = None
|
||||
alignments_backward = None
|
||||
|
||||
# set the alignment lengths wrt reduction factor for guided attention
|
||||
if mel_lengths.max() % model.decoder.r != 0:
|
||||
|
@ -176,7 +177,8 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
|
|||
loss_dict = criterion(postnet_output, decoder_output, mel_input,
|
||||
linear_input, stop_tokens, stop_targets,
|
||||
mel_lengths, decoder_backward_output,
|
||||
alignments, alignment_lengths, text_lengths)
|
||||
alignments, alignment_lengths, alignments_backward,
|
||||
text_lengths)
|
||||
if c.bidirectional_decoder:
|
||||
keep_avg.update_values({'avg_decoder_b_loss': loss_dict['decoder_backward_loss'].item(),
|
||||
'avg_decoder_c_loss': loss_dict['decoder_c_loss'].item()})
|
||||
|
|
|
@ -160,13 +160,16 @@ def setup_model(num_chars, num_speakers, c):
|
|||
location_attn=c.location_attn,
|
||||
attn_K=c.attention_heads,
|
||||
separate_stopnet=c.separate_stopnet,
|
||||
bidirectional_decoder=c.bidirectional_decoder)
|
||||
bidirectional_decoder=c.bidirectional_decoder,
|
||||
double_decoder_consistency=c.double_decoder_consistency,
|
||||
ddc_r=c.ddc_r)
|
||||
elif c.model.lower() == "tacotron2":
|
||||
model = MyModel(num_chars=num_chars,
|
||||
num_speakers=num_speakers,
|
||||
r=c.r,
|
||||
postnet_output_dim=c.audio['num_mels'],
|
||||
decoder_output_dim=c.audio['num_mels'],
|
||||
gst=c.use_gst,
|
||||
attn_type=c.attention_type,
|
||||
attn_win=c.windowing,
|
||||
attn_norm=c.attention_norm,
|
||||
|
@ -178,7 +181,9 @@ def setup_model(num_chars, num_speakers, c):
|
|||
location_attn=c.location_attn,
|
||||
attn_K=c.attention_heads,
|
||||
separate_stopnet=c.separate_stopnet,
|
||||
bidirectional_decoder=c.bidirectional_decoder)
|
||||
bidirectional_decoder=c.bidirectional_decoder,
|
||||
double_decoder_consistency=c.double_decoder_consistency,
|
||||
ddc_r=c.ddc_r)
|
||||
return model
|
||||
|
||||
class KeepAverage():
|
||||
|
@ -313,6 +318,8 @@ def check_config(c):
|
|||
_check_argument('transition_agent', c, restricted=True, val_type=bool)
|
||||
_check_argument('location_attn', c, restricted=True, val_type=bool)
|
||||
_check_argument('bidirectional_decoder', c, restricted=True, val_type=bool)
|
||||
_check_argument('double_decoder_consistency', c, restricted=True, val_type=bool)
|
||||
_check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)
|
||||
|
||||
# stopnet
|
||||
_check_argument('stopnet', c, restricted=True, val_type=bool)
|
||||
|
|
|
@ -77,6 +77,7 @@ class MultiScaleSTFTLoss(torch.nn.Module):
|
|||
|
||||
class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
|
||||
""" Multiscale STFT loss for multi band model outputs """
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, y_hat, y):
|
||||
y_hat = y_hat.view(-1, 1, y_hat.shape[2])
|
||||
y = y.view(-1, 1, y.shape[2])
|
||||
|
@ -85,6 +86,7 @@ class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
|
|||
|
||||
class MSEGLoss(nn.Module):
|
||||
""" Mean Squared Generator Loss """
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, score_fake):
|
||||
loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
|
||||
return loss_fake
|
||||
|
@ -92,6 +94,7 @@ class MSEGLoss(nn.Module):
|
|||
|
||||
class HingeGLoss(nn.Module):
|
||||
""" Hinge Discriminator Loss """
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, score_fake):
|
||||
loss_fake = torch.mean(F.relu(1. + score_fake))
|
||||
return loss_fake
|
||||
|
@ -104,6 +107,7 @@ class HingeGLoss(nn.Module):
|
|||
|
||||
class MSEDLoss(nn.Module):
|
||||
""" Mean Squared Discriminator Loss """
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, score_fake, score_real):
|
||||
loss_real = torch.mean(torch.sum(torch.pow(score_real - 1.0, 2), dim=[1, 2]))
|
||||
loss_fake = torch.mean(torch.sum(torch.pow(score_fake, 2), dim=[1, 2]))
|
||||
|
@ -113,6 +117,7 @@ class MSEDLoss(nn.Module):
|
|||
|
||||
class HingeDLoss(nn.Module):
|
||||
""" Hinge Discriminator Loss """
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, score_fake, score_real):
|
||||
loss_real = torch.mean(F.relu(1. - score_real))
|
||||
loss_fake = torch.mean(F.relu(1. + score_fake))
|
||||
|
@ -121,6 +126,7 @@ class HingeDLoss(nn.Module):
|
|||
|
||||
|
||||
class MelganFeatureLoss(nn.Module):
|
||||
# pylint: disable=no-self-use
|
||||
def forward(self, fake_feats, real_feats):
|
||||
loss_feats = 0
|
||||
for fake_feat, real_feat in zip(fake_feats, real_feats):
|
||||
|
@ -193,8 +199,8 @@ class GeneratorLoss(nn.Module):
|
|||
|
||||
self.stft_loss_weight = C.stft_loss_weight
|
||||
self.subband_stft_loss_weight = C.subband_stft_loss_weight
|
||||
self.mse_gan_loss_weight = C.mse_gan_loss_weight
|
||||
self.hinge_gan_loss_weight = C.hinge_gan_loss_weight
|
||||
self.mse_gan_loss_weight = C.mse_G_loss_weight
|
||||
self.hinge_gan_loss_weight = C.hinge_G_loss_weight
|
||||
self.feat_match_loss_weight = C.feat_match_loss_weight
|
||||
|
||||
if C.use_stft_loss:
|
||||
|
|
|
@ -52,7 +52,7 @@ def setup_loader(ap, is_val=False, verbose=False):
|
|||
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
loader = DataLoader(dataset,
|
||||
batch_size=1 if is_val else c.batch_size,
|
||||
shuffle=False,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
sampler=None,
|
||||
num_workers=c.num_val_loader_workers
|
||||
|
@ -120,11 +120,13 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
y_hat = model_G(c_G)
|
||||
y_hat_sub = None
|
||||
y_G_sub = None
|
||||
y_hat_vis = y_hat # for visualization
|
||||
|
||||
# PQMF formatting
|
||||
if y_hat.shape[1] > 1:
|
||||
y_hat_sub = y_hat
|
||||
y_hat = model_G.pqmf_synthesis(y_hat)
|
||||
y_hat_vis = y_hat
|
||||
y_G_sub = model_G.pqmf_analysis(y_G)
|
||||
|
||||
if global_step > c.steps_to_start_discriminator:
|
||||
|
@ -171,7 +173,10 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
|
||||
loss_dict = dict()
|
||||
for key, value in loss_G_dict.items():
|
||||
loss_dict[key] = value.item()
|
||||
if isinstance(value, int):
|
||||
loss_dict[key] = value
|
||||
else:
|
||||
loss_dict[key] = value.item()
|
||||
|
||||
##############################
|
||||
# DISCRIMINATOR
|
||||
|
@ -265,12 +270,12 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
model_losses=loss_dict)
|
||||
|
||||
# compute spectrograms
|
||||
figures = plot_results(y_hat, y_G, ap, global_step,
|
||||
figures = plot_results(y_hat_vis, y_G, ap, global_step,
|
||||
'train')
|
||||
tb_logger.tb_train_figures(global_step, figures)
|
||||
|
||||
# Sample audio
|
||||
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
||||
sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
|
||||
tb_logger.tb_train_audios(global_step,
|
||||
{'train/audio': sample_voice},
|
||||
c.audio["sample_rate"])
|
||||
|
@ -322,8 +327,12 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
|
|||
y_hat = model_G.pqmf_synthesis(y_hat)
|
||||
y_G_sub = model_G.pqmf_analysis(y_G)
|
||||
|
||||
D_out_fake = model_D(y_hat)
|
||||
if len(signature(model_D.forward).parameters) == 2:
|
||||
D_out_fake = model_D(y_hat, c_G)
|
||||
else:
|
||||
D_out_fake = model_D(y_hat)
|
||||
D_out_real = None
|
||||
|
||||
if c.use_feat_match_loss:
|
||||
with torch.no_grad():
|
||||
D_out_real = model_D(y_G)
|
||||
|
@ -354,7 +363,7 @@ def evaluate(model_G, criterion_G, model_D, ap, global_step, epoch):
|
|||
for key, value in loss_G_dict.items():
|
||||
update_eval_values['avg_' + key] = value.item()
|
||||
update_eval_values['avg_loader_time'] = loader_time
|
||||
update_eval_values['avgP_step_time'] = step_time
|
||||
update_eval_values['avg_step_time'] = step_time
|
||||
keep_avg.update_values(update_eval_values)
|
||||
|
||||
# print eval stats
|
||||
|
|
Loading…
Reference in New Issue