diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index 4d1574da..a02ff95a 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -10,8 +10,10 @@ class ConvBNBlock(nn.Module):
         super(ConvBNBlock, self).__init__()
         assert (kernel_size - 1) % 2 == 0
         padding = (kernel_size - 1) // 2
-        conv1d = nn.Conv1d(
-            in_channels, out_channels, kernel_size, padding=padding)
+        conv1d = nn.Conv1d(in_channels,
+                           out_channels,
+                           kernel_size,
+                           padding=padding)
         norm = nn.BatchNorm1d(out_channels)
         dropout = nn.Dropout(p=0.5)
         if nonlinear == 'relu':
@@ -52,20 +54,20 @@ class Encoder(nn.Module):
             convolutions.append(
                 ConvBNBlock(in_features, in_features, 5, 'relu'))
         self.convolutions = nn.Sequential(*convolutions)
-        self.lstm = nn.LSTM(
-            in_features,
-            int(in_features / 2),
-            num_layers=1,
-            batch_first=True,
-            bidirectional=True)
+        self.lstm = nn.LSTM(in_features,
+                            int(in_features / 2),
+                            num_layers=1,
+                            batch_first=True,
+                            bidirectional=True)
         self.rnn_state = None

     def forward(self, x, input_lengths):
         x = self.convolutions(x)
         x = x.transpose(1, 2)
         input_lengths = input_lengths.cpu().numpy()
-        x = nn.utils.rnn.pack_padded_sequence(
-            x, input_lengths, batch_first=True)
+        x = nn.utils.rnn.pack_padded_sequence(x,
+                                              input_lengths,
+                                              batch_first=True)
         self.lstm.flatten_parameters()
         outputs, _ = self.lstm(x)
         outputs, _ = nn.utils.rnn.pad_packed_sequence(
@@ -112,9 +114,11 @@ class Decoder(nn.Module):
         self.gate_threshold = 0.5
         self.p_attention_dropout = 0.1
         self.p_decoder_dropout = 0.1
-        self.prenet = Prenet(self.mel_channels, prenet_type,
+        self.prenet = Prenet(self.mel_channels,
+                             prenet_type,
                              prenet_dropout,
-                             [self.prenet_dim, self.prenet_dim], bias=False)
+                             [self.prenet_dim, self.prenet_dim],
+                             bias=False)

         self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                          self.query_dim)
@@ -139,19 +143,20 @@ class Decoder(nn.Module):

         self.stopnet = nn.Sequential(
             nn.Dropout(0.1),
-            Linear(
-                self.decoder_rnn_dim + self.mel_channels * self.r_init,
-                1,
-                bias=True,
-                init_gain='sigmoid'))
+            Linear(self.decoder_rnn_dim + self.mel_channels * self.r_init,
+                   1,
+                   bias=True,
+                   init_gain='sigmoid'))
         self.memory_truncated = None

     def set_r(self, new_r):
         self.r = new_r
-
+
     def get_go_frame(self, inputs):
         B = inputs.size(0)
-        memory = torch.zeros(B, self.mel_channels * self.r, device=inputs.device)
+        memory = torch.zeros(B,
+                             self.mel_channels * self.r,
+                             device=inputs.device)
         return memory

     def _init_states(self, inputs, mask, keep_states=False):
@@ -159,17 +164,25 @@ class Decoder(nn.Module):
         # T = inputs.size(1)
         if not keep_states:
             self.query = torch.zeros(B, self.query_dim, device=inputs.device)
-            self.attention_rnn_cell_state = torch.zeros(B, self.query_dim, device=inputs.device)
-            self.decoder_hidden = torch.zeros(B, self.decoder_rnn_dim, device=inputs.device)
-            self.decoder_cell = torch.zeros(B, self.decoder_rnn_dim, device=inputs.device)
-            self.context = torch.zeros(B, self.encoder_embedding_dim, device=inputs.device)
+            self.attention_rnn_cell_state = torch.zeros(B,
+                                                        self.query_dim,
+                                                        device=inputs.device)
+            self.decoder_hidden = torch.zeros(B,
+                                              self.decoder_rnn_dim,
+                                              device=inputs.device)
+            self.decoder_cell = torch.zeros(B,
+                                            self.decoder_rnn_dim,
+                                            device=inputs.device)
+            self.context = torch.zeros(B,
+                                       self.encoder_embedding_dim,
+                                       device=inputs.device)
         self.inputs = inputs
         self.processed_inputs = self.attention.inputs_layer(inputs)
         self.mask = mask

     def _reshape_memory(self, memories):
-        memories = memories.view(
-            memories.size(0), int(memories.size(1) / self.r), -1)
+        memories = memories.view(memories.size(0),
+                                 int(memories.size(1) / self.r), -1)
         memories = memories.transpose(0, 1)
         return memories

@@ -184,18 +197,18 @@ class Decoder(nn.Module):

     def _update_memory(self, memory):
         if len(memory.shape) == 2:
-            return memory[:, self.mel_channels * (self.r - 1) :]
-        else:
-            return memory[:, :, self.mel_channels * (self.r - 1) :]
+            return memory[:, self.mel_channels * (self.r - 1):]
+        return memory[:, :, self.mel_channels * (self.r - 1):]

     def decode(self, memory):
         query_input = torch.cat((memory, self.context), -1)
         self.query, self.attention_rnn_cell_state = self.attention_rnn(
             query_input, (self.query, self.attention_rnn_cell_state))
-        self.query = F.dropout(
-            self.query, self.p_attention_dropout, self.training)
+        self.query = F.dropout(self.query, self.p_attention_dropout,
+                               self.training)
         self.attention_rnn_cell_state = F.dropout(
-            self.attention_rnn_cell_state, self.p_attention_dropout, self.training)
+            self.attention_rnn_cell_state, self.p_attention_dropout,
+            self.training)
         self.context = self.attention(self.query, self.inputs,
                                       self.processed_inputs, self.mask)
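Note on the Decoder changes above: the decoder emits r mel frames per step, and _reshape_memory folds the target mel sequence along time before the teacher-forcing loop. The following is a minimal standalone sketch of that folding; the tensor sizes (B, T, mel_channels, r) are made up for illustration and are not taken from this patch or any config.

    import torch

    # Fold r consecutive frames into one decoder step, mirroring
    # memories.view(B, T // r, -1) followed by transpose(0, 1) above.
    B, T, mel_channels, r = 2, 6, 80, 2
    mel = torch.randn(B, T, mel_channels)
    folded = mel.view(B, T // r, mel_channels * r)   # (B, T/r, mel_channels * r)
    per_step = folded.transpose(0, 1)                # (T/r, B, mel_channels * r)
    assert per_step.shape == (T // r, B, mel_channels * r)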
diff --git a/models/tacotron.py b/models/tacotron.py
index 8f40f313..8f711364 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -3,6 +3,7 @@ import torch
 from torch import nn
 from TTS.layers.tacotron import Encoder, Decoder, PostCBHG
 from TTS.utils.generic_utils import sequence_mask
+from TTS.layers.gst_layers import GST


 class Tacotron(nn.Module):
@@ -14,6 +15,7 @@ class Tacotron(nn.Module):
                  mel_dim=80,
                  memory_size=5,
                  attn_win=False,
+                 gst=False,
                  attn_norm="sigmoid",
                  prenet_type="original",
                  prenet_dropout=True,
@@ -26,35 +28,59 @@ class Tacotron(nn.Module):
         self.r = r
         self.mel_dim = mel_dim
         self.linear_dim = linear_dim
+        self.gst = gst
         self.num_speakers = num_speakers
         self.embedding = nn.Embedding(num_chars, 256)
         self.embedding.weight.data.normal_(0, 0.3)
         decoder_dim = 512 if num_speakers > 1 else 256
         encoder_dim = 512 if num_speakers > 1 else 256
         proj_speaker_dim = 80 if num_speakers > 1 else 0
-        if num_speakers > 1:
-            self.speaker_embedding = nn.Embedding(num_speakers, 256)
-            self.speaker_embedding.weight.data.normal_(0, 0.3)
-            self.speaker_project_mel = nn.Sequential(nn.Linear(256, proj_speaker_dim), nn.Tanh())
+        # boilerplate model
         self.encoder = Encoder(encoder_dim)
         self.decoder = Decoder(decoder_dim, mel_dim, r, memory_size, attn_win,
                                attn_norm, prenet_type, prenet_dropout,
                                forward_attn, trans_agent, forward_attn_mask,
-                               location_attn, separate_stopnet, proj_speaker_dim)
+                               location_attn, separate_stopnet,
+                               proj_speaker_dim)
         self.postnet = PostCBHG(mel_dim)
-        self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim)
+        self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2,
+                                     linear_dim)
+        # speaker embedding layers
+        if num_speakers > 1:
+            self.speaker_embedding = nn.Embedding(num_speakers, 256)
+            self.speaker_embedding.weight.data.normal_(0, 0.3)
+            self.speaker_project_mel = nn.Sequential(
+                nn.Linear(256, proj_speaker_dim), nn.Tanh())
+        self.speaker_embeddings = None
+        self.speaker_embeddings_projected = None
+        # global style token layers
+        if self.gst:
+            gst_embedding_dim = 256
+            self.gst_layer = GST(num_mel=80,
+                                 num_heads=4,
+                                 num_style_tokens=10,
+                                 embedding_dim=gst_embedding_dim)

     def _init_states(self):
-        self.speaker_embeddings = None
+        self.speaker_embeddings = None
         self.speaker_embeddings_projected = None

     def compute_speaker_embedding(self, speaker_ids):
         if hasattr(self, "speaker_embedding") and speaker_ids is None:
-            raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
+            raise RuntimeError(
+                " [!] Model has speaker embedding layer but speaker_id is not provided"
+            )
         if hasattr(self, "speaker_embedding") and speaker_ids is not None:
-            self.speaker_embeddings = self._compute_speaker_embedding(speaker_ids)
-            self.speaker_embeddings_projected = self.speaker_project_mel(self.speaker_embeddings).squeeze(1)
-
+            self.speaker_embeddings = self._compute_speaker_embedding(
+                speaker_ids)
+            self.speaker_embeddings_projected = self.speaker_project_mel(
+                self.speaker_embeddings).squeeze(1)
+
+    def compute_gst(self, inputs, mel_specs):
+        gst_outputs = self.gst_layer(mel_specs)
+        inputs = self._add_speaker_embedding(inputs, gst_outputs)
+        return inputs
+
     def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
         B = characters.size(0)
         mask = sequence_mask(text_lengths).to(characters.device)
@@ -63,30 +89,35 @@ class Tacotron(nn.Module):
         self.compute_speaker_embedding(speaker_ids)
         if self.num_speakers > 1:
             inputs = self._concat_speaker_embedding(inputs,
-                                            self.speaker_embeddings)
+                                                    self.speaker_embeddings)
         encoder_outputs = self.encoder(inputs)
+        if self.gst:
+            encoder_outputs = self.compute_gst(encoder_outputs, mel_specs)
         if self.num_speakers > 1:
-            encoder_outputs = self._concat_speaker_embedding(encoder_outputs,
-                                            self.speaker_embeddings)
+            encoder_outputs = self._concat_speaker_embedding(
+                encoder_outputs, self.speaker_embeddings)
         mel_outputs, alignments, stop_tokens = self.decoder(
-            encoder_outputs, mel_specs, mask, self.speaker_embeddings_projected)
+            encoder_outputs, mel_specs, mask,
+            self.speaker_embeddings_projected)
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
         linear_outputs = self.postnet(mel_outputs)
         linear_outputs = self.last_linear(linear_outputs)
         return mel_outputs, linear_outputs, alignments, stop_tokens

-    def inference(self, characters, speaker_ids=None):
+    def inference(self, characters, speaker_ids=None, style_mel=None):
         B = characters.size(0)
         inputs = self.embedding(characters)
         self._init_states()
         self.compute_speaker_embedding(speaker_ids)
         if self.num_speakers > 1:
             inputs = self._concat_speaker_embedding(inputs,
-                                            self.speaker_embeddings)
+                                                    self.speaker_embeddings)
         encoder_outputs = self.encoder(inputs)
+        if self.gst and style_mel is not None:
+            encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
         if self.num_speakers > 1:
-            encoder_outputs = self._concat_speaker_embedding(encoder_outputs,
-                                            self.speaker_embeddings)
+            encoder_outputs = self._concat_speaker_embedding(
+                encoder_outputs, self.speaker_embeddings)
         mel_outputs, alignments, stop_tokens = self.decoder.inference(
             encoder_outputs, self.speaker_embeddings_projected)
         mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
@@ -98,16 +129,16 @@ class Tacotron(nn.Module):
         speaker_embeddings = self.speaker_embedding(speaker_ids)
         return speaker_embeddings.unsqueeze_(1)

-    def _add_speaker_embedding(self, outputs, speaker_embeddings):
-        speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0),
-                                                        outputs.size(1),
-                                                        -1)
+    @staticmethod
+    def _add_speaker_embedding(outputs, speaker_embeddings):
+        speaker_embeddings_ = speaker_embeddings.expand(
+            outputs.size(0), outputs.size(1), -1)
         outputs = outputs + speaker_embeddings_
         return outputs

-    def _concat_speaker_embedding(self, outputs, speaker_embeddings):
-        speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0),
-                                                        outputs.size(1),
-                                                        -1)
+    @staticmethod
+    def _concat_speaker_embedding(outputs, speaker_embeddings):
+        speaker_embeddings_ = speaker_embeddings.expand(
+            outputs.size(0), outputs.size(1), -1)
         outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
         return outputs
diff --git a/models/tacotrongst.py b/models/tacotrongst.py
deleted file mode 100644
index 9819ec53..00000000
--- a/models/tacotrongst.py
+++ /dev/null
@@ -1,97 +0,0 @@
-# coding: utf-8
-import torch
-from torch import nn
-from TTS.layers.tacotron import Encoder, Decoder, PostCBHG
-from TTS.layers.gst_layers import GST
-from TTS.utils.generic_utils import sequence_mask
-from TTS.models.tacotron import Tacotron
-
-
-class TacotronGST(Tacotron):
-    def __init__(self,
-                 num_chars,
-                 num_speakers,
-                 r=5,
-                 linear_dim=1025,
-                 mel_dim=80,
-                 memory_size=5,
-                 attn_win=False,
-                 attn_norm="sigmoid",
-                 prenet_type="original",
-                 prenet_dropout=True,
-                 forward_attn=False,
-                 trans_agent=False,
-                 forward_attn_mask=False,
-                 location_attn=True,
-                 separate_stopnet=True):
-        super().__init__(num_chars,
-                         num_speakers,
-                         r,
-                         linear_dim,
-                         mel_dim,
-                         memory_size,
-                         attn_win,
-                         attn_norm,
-                         prenet_type,
-                         prenet_dropout,
-                         forward_attn,
-                         trans_agent,
-                         forward_attn_mask,
-                         location_attn,
-                         separate_stopnet)
-        gst_embedding_dim = 256
-        decoder_dim = 512 if num_speakers > 1 else 256
-        proj_speaker_dim = 80 if num_speakers > 1 else 0
-        self.decoder = Decoder(decoder_dim, mel_dim, r, memory_size, attn_win,
-                               attn_norm, prenet_type, prenet_dropout,
-                               forward_attn, trans_agent, forward_attn_mask,
-                               location_attn, separate_stopnet, proj_speaker_dim)
-        self.gst = GST(num_mel=80, num_heads=4,
-                       num_style_tokens=10, embedding_dim=gst_embedding_dim)
-
-    def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
-        B = characters.size(0)
-        mask = sequence_mask(text_lengths).to(characters.device)
-        inputs = self.embedding(characters)
-        self._init_states()
-        self.compute_speaker_embedding(speaker_ids)
-        if self.num_speakers > 1:
-            inputs = self._add_speaker_embedding(inputs,
-                                                 self.speaker_embeddings)
-        encoder_outputs = self.encoder(inputs)
-        if self.num_speakers > 1:
-            encoder_outputs = self._add_speaker_embedding(encoder_outputs,
-                                                          self.speaker_embeddings)
-        gst_outputs = self.gst(mel_specs)
-        encoder_outputs = self._add_speaker_embedding(
-            encoder_outputs, gst_outputs)
-        mel_outputs, alignments, stop_tokens = self.decoder(
-            encoder_outputs, mel_specs, mask, self.speaker_embeddings_projected)
-        mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
-        linear_outputs = self.postnet(mel_outputs)
-        linear_outputs = self.last_linear(linear_outputs)
-        return mel_outputs, linear_outputs, alignments, stop_tokens
-
-    def inference(self, characters, speaker_ids=None, style_mel=None):
-        B = characters.size(0)
-        inputs = self.embedding(characters)
-        self._init_states()
-        self.compute_speaker_embedding(speaker_ids)
-        if self.num_speakers > 1:
-            inputs = self._add_speaker_embedding(inputs,
-                                                 self.speaker_embeddings)
-        encoder_outputs = self.encoder(inputs)
-        if self.num_speakers > 1:
-            encoder_outputs = self._add_speaker_embedding(encoder_outputs,
-                                                          self.speaker_embeddings)
-        if style_mel is not None:
-            gst_outputs = self.gst(style_mel)
-            gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1)
-            encoder_outputs = self._add_speaker_embedding(encoder_outputs,
-                                                          gst_outputs)
-        mel_outputs, alignments, stop_tokens = self.decoder.inference(
-            encoder_outputs, self.speaker_embeddings_projected)
-        mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
-        linear_outputs = self.postnet(mel_outputs)
-        linear_outputs = self.last_linear(linear_outputs)
-        return mel_outputs, linear_outputs, alignments, stop_tokens
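With models/tacotrongst.py deleted, the GST path now lives in Tacotron itself: pass gst=True to the constructor and, at inference time, a reference spectrogram via style_mel, as the updated tests below do. compute_gst and the now-static speaker helpers all use the same broadcast-and-combine pattern; here is a minimal shape sketch, with B, T, and D chosen purely for illustration.

    import torch

    B, T, D = 2, 7, 256
    encoder_outputs = torch.randn(B, T, D)
    embedding = torch.randn(B, 1, D)       # a speaker or GST style embedding

    expanded = embedding.expand(B, T, -1)                          # broadcast along time, no copy
    added = encoder_outputs + expanded                             # _add_speaker_embedding / compute_gst
    concatenated = torch.cat([encoder_outputs, expanded], dim=-1)  # _concat_speaker_embedding, (B, T, 2 * D)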
diff --git a/tests/test_layers.py b/tests/test_layers.py
index a465a898..6b5fd80b 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -67,7 +67,8 @@ class DecoderTests(unittest.TestCase):
         assert output.shape[2] == 80 * 2, "size not {}".format(output.shape[2])
         assert stop_tokens.shape[0] == 4

-    def test_in_out_multispeaker(self):
+    @staticmethod
+    def test_in_out_multispeaker():
         layer = Decoder(
             in_features=256,
             memory_dim=80,
diff --git a/tests/test_tacotron_model.py b/tests/test_tacotron_model.py
index 9b8de336..c8b0d7ca 100644
--- a/tests/test_tacotron_model.py
+++ b/tests/test_tacotron_model.py
@@ -8,7 +8,6 @@ from torch import nn
 from TTS.utils.generic_utils import load_config
 from TTS.layers.losses import L1LossMasked
 from TTS.models.tacotron import Tacotron
-from TTS.models.tacotrongst import TacotronGST

 #pylint: disable=unused-variable

@@ -25,68 +24,72 @@ def count_parameters(model):
     return sum(p.numel() for p in model.parameters() if p.requires_grad)


-# class TacotronTrainTest(unittest.TestCase):
-    # def test_train_step(self):
-        # input = torch.randint(0, 24, (8, 128)).long().to(device)
-        # input_lengths = torch.randint(100, 129, (8, )).long().to(device)
-        # input_lengths[-1] = 128
-        # mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
-        # linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device)
-        # mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
-        # stop_targets = torch.zeros(8, 30, 1).float().to(device)
-        # speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
+class TacotronTrainTest(unittest.TestCase):
+    @staticmethod
+    def test_train_step():
+        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
+        input_lengths = torch.randint(100, 129, (8, )).long().to(device)
+        input_lengths[-1] = 128
+        mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
+        linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device)
+        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
+        stop_targets = torch.zeros(8, 30, 1).float().to(device)
+        speaker_ids = torch.randint(0, 5, (8, )).long().to(device)

-        # for idx in mel_lengths:
-            # stop_targets[:, int(idx.item()):, 0] = 1.0
+        for idx in mel_lengths:
+            stop_targets[:, int(idx.item()):, 0] = 1.0

-        # stop_targets = stop_targets.view(input.shape[0],
-                                         # stop_targets.size(1) // c.r, -1)
-        # stop_targets = (stop_targets.sum(2) >
-                        # 0.0).unsqueeze(2).float().squeeze()
+        stop_targets = stop_targets.view(input_dummy.shape[0],
+                                         stop_targets.size(1) // c.r, -1)
+        stop_targets = (stop_targets.sum(2) >
+                        0.0).unsqueeze(2).float().squeeze()

-        # criterion = L1LossMasked().to(device)
-        # criterion_st = nn.BCEWithLogitsLoss().to(device)
-        # model = Tacotron(
-            # num_chars=32,
-            # num_speakers=5,
-            # linear_dim=c.audio['num_freq'],
-            # mel_dim=c.audio['num_mels'],
-            # r=c.r,
-            # memory_size=c.memory_size).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor
-        # model.train()
-        # print(" > Num parameters for Tacotron model:%s"%(count_parameters(model)))
-        # model_ref = copy.deepcopy(model)
-        # count = 0
-        # for param, param_ref in zip(model.parameters(),
-                                    # model_ref.parameters()):
-            # assert (param - param_ref).sum() == 0, param
-            # count += 1
-        # optimizer = optim.Adam(model.parameters(), lr=c.lr)
-        # for _ in range(5):
-            # mel_out, linear_out, align, stop_tokens = model.forward(
-                # input, input_lengths, mel_spec, speaker_ids)
-            # optimizer.zero_grad()
-            # loss = criterion(mel_out, mel_spec, mel_lengths)
-            # stop_loss = criterion_st(stop_tokens, stop_targets)
-            # loss = loss + criterion(linear_out, linear_spec,
-                                    # mel_lengths) + stop_loss
-            # loss.backward()
-            # optimizer.step()
-        # # check parameter changes
-        # count = 0
-        # for param, param_ref in zip(model.parameters(),
-                                    # model_ref.parameters()):
-            # # ignore pre-higway layer since it works conditional
-            # # if count not in [145, 59]:
-            # assert (param != param_ref).any(
-            # ), "param {} with shape {} not updated!! \n{}\n{}".format(
-                # count, param.shape, param, param_ref)
-            # count += 1
+        criterion = L1LossMasked().to(device)
+        criterion_st = nn.BCEWithLogitsLoss().to(device)
+        model = Tacotron(
+            num_chars=32,
+            num_speakers=5,
+            linear_dim=c.audio['num_freq'],
+            mel_dim=c.audio['num_mels'],
+            r=c.r,
+            memory_size=c.memory_size
+        ).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor
+        model.train()
+        print(" > Num parameters for Tacotron model:%s" %
+              (count_parameters(model)))
+        model_ref = copy.deepcopy(model)
+        count = 0
+        for param, param_ref in zip(model.parameters(),
+                                    model_ref.parameters()):
+            assert (param - param_ref).sum() == 0, param
+            count += 1
+        optimizer = optim.Adam(model.parameters(), lr=c.lr)
+        for _ in range(5):
+            mel_out, linear_out, align, stop_tokens = model.forward(
+                input_dummy, input_lengths, mel_spec, speaker_ids)
+            optimizer.zero_grad()
+            loss = criterion(mel_out, mel_spec, mel_lengths)
+            stop_loss = criterion_st(stop_tokens, stop_targets)
+            loss = loss + criterion(linear_out, linear_spec,
+                                    mel_lengths) + stop_loss
+            loss.backward()
+            optimizer.step()
+        # check parameter changes
+        count = 0
+        for param, param_ref in zip(model.parameters(),
+                                    model_ref.parameters()):
+            # ignore pre-higway layer since it works conditional
+            # if count not in [145, 59]:
+            assert (param != param_ref).any(
+            ), "param {} with shape {} not updated!! \n{}\n{}".format(
+                count, param.shape, param, param_ref)
+            count += 1


 class TacotronGSTTrainTest(unittest.TestCase):
-    def test_train_step(self):
-        input = torch.randint(0, 24, (8, 128)).long().to(device)
+    @staticmethod
+    def test_train_step():
+        input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 129, (8, )).long().to(device)
         input_lengths[-1] = 128
         mel_spec = torch.rand(8, 120, c.audio['num_mels']).to(device)
@@ -98,23 +101,26 @@ class TacotronGSTTrainTest(unittest.TestCase):

         for idx in mel_lengths:
             stop_targets[:, int(idx.item()):, 0] = 1.0

-        stop_targets = stop_targets.view(input.shape[0],
+        stop_targets = stop_targets.view(input_dummy.shape[0],
                                          stop_targets.size(1) // c.r, -1)
         stop_targets = (stop_targets.sum(2) >
                         0.0).unsqueeze(2).float().squeeze()

         criterion = L1LossMasked().to(device)
         criterion_st = nn.BCEWithLogitsLoss().to(device)
-        model = TacotronGST(
+        model = Tacotron(
             num_chars=32,
-            num_speakers=5,
+            num_speakers=5,
+            gst=True,
             linear_dim=c.audio['num_freq'],
             mel_dim=c.audio['num_mels'],
             r=c.r,
-            memory_size=c.memory_size).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor
+            memory_size=c.memory_size
+        ).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor
         model.train()
         print(model)
-        print(" > Num parameters for Tacotron GST model:%s"%(count_parameters(model)))
+        print(" > Num parameters for Tacotron GST model:%s" %
+              (count_parameters(model)))
         model_ref = copy.deepcopy(model)
         count = 0
         for param, param_ref in zip(model.parameters(),
@@ -124,7 +130,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
         optimizer = optim.Adam(model.parameters(), lr=c.lr)
         for _ in range(10):
             mel_out, linear_out, align, stop_tokens = model.forward(
-                input, input_lengths, mel_spec, speaker_ids)
+                input_dummy, input_lengths, mel_spec, speaker_ids)
             optimizer.zero_grad()
             loss = criterion(mel_out, mel_spec, mel_lengths)
             stop_loss = criterion_st(stop_tokens, stop_targets)
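Both train tests fold the per-frame stop targets by the reduction factor before comparing them with the model's per-step stop tokens. A standalone sketch of that folding; batch size, length, and r are invented for illustration and the squeeze/unsqueeze bookkeeping of the tests is omitted.

    import torch

    B, T, r = 8, 30, 5
    stop_targets = torch.zeros(B, T, 1)
    stop_targets[:, 25:, 0] = 1.0                  # frames past the utterance end are stop frames
    folded = stop_targets.view(B, T // r, -1)      # (B, T/r, r): r flags per decoder step
    stop_per_step = (folded.sum(2) > 0.0).float()  # a step stops if any of its r frames does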