diff --git a/layers/tacotron.py b/layers/tacotron.py
index 3d505ddb..931388bf 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -302,13 +302,14 @@ class Decoder(nn.Module):
     """
 
     def __init__(self, in_features, memory_dim, r, memory_size,
-                 attn_windowing, attn_norm):
+                 attn_windowing, attn_norm, separate_stopnet):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
         self.max_decoder_steps = 500
         self.memory_size = memory_size if memory_size > 0 else r
         self.memory_dim = memory_dim
+        self.separate_stopnet = separate_stopnet
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(
             memory_dim * self.memory_size, out_features=[256, 128])
@@ -415,7 +416,10 @@ class Decoder(nn.Module):
         # predict stop token
         stopnet_input = torch.cat([decoder_output, output], -1)
         del decoder_output
-        stop_token = self.stopnet(stopnet_input)
+        if self.separate_stopnet:
+            stop_token = self.stopnet(stopnet_input.detach())
+        else:
+            stop_token = self.stopnet(stopnet_input)
         return output, stop_token, self.attention
 
     def _update_memory_queue(self, new_memory):
diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index 6bb08ce1..32fd5ace 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -44,7 +44,7 @@ class LinearBN(nn.Module):
 
     def forward(self, x):
         out = self.linear_layer(x)
-        if len(out.shape)==3:
+        if len(out.shape) == 3:
             out = out.permute(1, 2, 0)
         out = self.bn(out)
         if len(out.shape) == 3:
@@ -53,7 +53,11 @@ class LinearBN(nn.Module):
 
 
 class Prenet(nn.Module):
-    def __init__(self, in_features, prenet_type, prenet_dropout, out_features=[256, 256]):
+    def __init__(self,
+                 in_features,
+                 prenet_type,
+                 prenet_dropout,
+                 out_features=[256, 256]):
         super(Prenet, self).__init__()
         self.prenet_type = prenet_type
         self.prenet_dropout = prenet_dropout
@@ -64,8 +68,8 @@ class Prenet(nn.Module):
                 for (in_size, out_size) in zip(in_features, out_features)
             ])
         elif prenet_type == "original":
-            self.layers = nn.ModuleList(
-                [Linear(in_size, out_size, bias=False)
+            self.layers = nn.ModuleList([
+                Linear(in_size, out_size, bias=False)
                 for (in_size, out_size) in zip(in_features, out_features)
             ])
 
@@ -76,7 +80,7 @@ class Prenet(nn.Module):
             else:
                 x = F.relu(linear(x))
         return x
-
+
 
 class ConvBNBlock(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size, nonlinear=None):
@@ -121,9 +125,10 @@ class LocationLayer(nn.Module):
 
 
 class Attention(nn.Module):
-    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, location_attention,
-                 attention_location_n_filters, attention_location_kernel_size,
-                 windowing, norm, forward_attn, trans_agent):
+    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
+                 location_attention, attention_location_n_filters,
+                 attention_location_kernel_size, windowing, norm, forward_attn,
+                 trans_agent):
         super(Attention, self).__init__()
         self.query_layer = Linear(
             attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
@@ -131,11 +136,12 @@ class Attention(nn.Module):
             embedding_dim, attention_dim, bias=False, init_gain='tanh')
         self.v = Linear(attention_dim, 1, bias=True)
         if trans_agent:
-            self.ta = nn.Linear(attention_rnn_dim + embedding_dim, 1, bias=True)
+            self.ta = nn.Linear(
+                attention_rnn_dim + embedding_dim, 1, bias=True)
         if location_attention:
-            self.location_layer = LocationLayer(attention_location_n_filters,
-                                                attention_location_kernel_size,
-                                                attention_dim)
+            self.location_layer = LocationLayer(
+                attention_location_n_filters, attention_location_kernel_size,
+                attention_dim)
         self._mask_value = -float("inf")
         self.windowing = windowing
         self.win_idx = None
@@ -148,11 +154,13 @@ class Attention(nn.Module):
         self.win_idx = -1
         self.win_back = 2
         self.win_front = 6
-
+
     def init_forward_attn(self, inputs):
         B = inputs.shape[0]
         T = inputs.shape[1]
-        self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1] + 1e-7 ], dim=1).to(inputs.device)
+        self.alpha = torch.cat(
+            [torch.ones([B, 1]),
+             torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device)
         self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)
 
     def init_location_attention(self, inputs):
@@ -182,14 +190,13 @@ class Attention(nn.Module):
         processed_attention_weights = self.location_layer(attention_cat)
         energies = self.v(
             torch.tanh(processed_query + processed_attention_weights +
-                        processed_inputs))
+                       processed_inputs))
         energies = energies.squeeze(-1)
         return energies, processed_query
 
     def get_attention(self, query, processed_inputs):
         processed_query = self.query_layer(query.unsqueeze(1))
-        energies = self.v(
-            torch.tanh(processed_query +processed_inputs))
+        energies = self.v(torch.tanh(processed_query + processed_inputs))
         energies = energies.squeeze(-1)
         return energies, processed_query
 
@@ -210,8 +217,10 @@ class Attention(nn.Module):
 
     def apply_forward_attention(self, inputs, alignment, query):
         # forward attention
-        prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
-        alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-8) * alignment
+        prev_alpha = F.pad(self.alpha[:, :-1].clone(),
+                           (1, 0, 0, 0)).to(inputs.device)
+        alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
+                  self.u * prev_alpha) + 1e-8) * alignment
         self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
         # compute context
         context = torch.bmm(self.alpha.unsqueeze(1), inputs)
@@ -222,8 +231,7 @@ class Attention(nn.Module):
             self.u = torch.sigmoid(self.ta(ta_input))
         return context, self.alpha
 
-    def forward(self, attention_hidden_state, inputs, processed_inputs,
-                mask):
+    def forward(self, attention_hidden_state, inputs, processed_inputs, mask):
         if self.location_attention:
             attention, processed_query = self.get_location_attention(
                 attention_hidden_state, processed_inputs)
@@ -241,14 +249,15 @@ class Attention(nn.Module):
             alignment = torch.softmax(attention, dim=-1)
         elif self.norm == "sigmoid":
             alignment = torch.sigmoid(attention) / torch.sigmoid(
-                    attention).sum(dim=1).unsqueeze(1)
+                attention).sum(dim=1).unsqueeze(1)
         else:
             raise RuntimeError("Unknown value for attention norm type")
         if self.location_attention:
             self.update_location_attention(alignment)
         # apply forward attention if enabled
         if self.forward_attn:
-            context, self.attention_weights = self.apply_forward_attention(inputs, alignment, attention_hidden_state)
+            context, self.attention_weights = self.apply_forward_attention(
+                inputs, alignment, attention_hidden_state)
         else:
             context = torch.bmm(alignment.unsqueeze(1), inputs)
             context = context.squeeze(1)
@@ -321,13 +330,17 @@ class Encoder(nn.Module):
         outputs, self.rnn_state = self.lstm(x, self.rnn_state)
         return outputs
 
+
 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
-    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, location_attn):
+    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm,
+                 prenet_type, prenet_dropout, forward_attn, trans_agent,
+                 location_attn, separate_stopnet):
         super(Decoder, self).__init__()
         self.mel_channels = inputs_dim
         self.r = r
         self.encoder_embedding_dim = in_features
+        self.separate_stopnet = separate_stopnet
         self.attention_rnn_dim = 1024
         self.decoder_rnn_dim = 1024
         self.prenet_dim = 256
@@ -336,14 +349,16 @@ class Decoder(nn.Module):
         self.p_attention_dropout = 0.1
         self.p_decoder_dropout = 0.1
 
-        self.prenet = Prenet(self.mel_channels * r, prenet_type, prenet_dropout,
+        self.prenet = Prenet(self.mel_channels * r, prenet_type,
+                             prenet_dropout,
                              [self.prenet_dim, self.prenet_dim])
 
         self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                          self.attention_rnn_dim)
 
-        self.attention_layer = Attention(self.attention_rnn_dim, in_features, 128, location_attn,
-                                         32, 31, attn_win, attn_norm, forward_attn, trans_agent)
+        self.attention_layer = Attention(self.attention_rnn_dim, in_features,
+                                         128, location_attn, 32, 31, attn_win,
+                                         attn_norm, forward_attn, trans_agent)
 
         self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                        self.decoder_rnn_dim, 1)
@@ -353,10 +368,11 @@ class Decoder(nn.Module):
 
         self.stopnet = nn.Sequential(
             nn.Dropout(0.1),
-            Linear(self.decoder_rnn_dim + self.mel_channels * r,
-                   1,
-                   bias=True,
-                   init_gain='sigmoid'))
+            Linear(
+                self.decoder_rnn_dim + self.mel_channels * r,
+                1,
+                bias=True,
+                init_gain='sigmoid'))
         self.attention_rnn_init = nn.Embedding(1, self.attention_rnn_dim)
         self.go_frame_init = nn.Embedding(1, self.mel_channels * r)
 
@@ -382,10 +398,10 @@ class Decoder(nn.Module):
             inputs.data.new_zeros(B).long())
         self.decoder_cell = Variable(
             inputs.data.new(B, self.decoder_rnn_dim).zero_())
-
+
         self.context = Variable(
-                inputs.data.new(B, self.encoder_embedding_dim).zero_())
-
+            inputs.data.new(B, self.encoder_embedding_dim).zero_())
+
         self.inputs = inputs
         self.processed_inputs = self.attention_layer.inputs_layer(inputs)
         self.mask = mask
@@ -401,8 +417,7 @@ class Decoder(nn.Module):
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
         stop_tokens = stop_tokens.contiguous()
         outputs = torch.stack(outputs).transpose(0, 1).contiguous()
-        outputs = outputs.view(
-            outputs.size(0), -1, self.mel_channels)
+        outputs = outputs.view(outputs.size(0), -1, self.mel_channels)
         outputs = outputs.transpose(1, 2)
         return outputs, stop_tokens, alignments
 
@@ -415,12 +430,10 @@ class Decoder(nn.Module):
         self.attention_cell = F.dropout(
             self.attention_cell, self.p_attention_dropout, self.training)
 
-        self.context = self.attention_layer(
-            self.attention_hidden, self.inputs, self.processed_inputs,
-            self.mask)
+        self.context = self.attention_layer(self.attention_hidden, self.inputs,
+                                            self.processed_inputs, self.mask)
 
-        memory = torch.cat(
-            (self.attention_hidden, self.context), -1)
+        memory = torch.cat((self.attention_hidden, self.context), -1)
         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
             memory, (self.decoder_hidden, self.decoder_cell))
         self.decoder_hidden = F.dropout(self.decoder_hidden,
@@ -428,16 +441,18 @@ class Decoder(nn.Module):
         self.decoder_cell = F.dropout(self.decoder_cell,
                                       self.p_decoder_dropout, self.training)
 
-        decoder_hidden_context = torch.cat(
-            (self.decoder_hidden, self.context), dim=1)
+        decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
+                                           dim=1)
 
-        decoder_output = self.linear_projection(
-            decoder_hidden_context)
+        decoder_output = self.linear_projection(decoder_hidden_context)
 
         stopnet_input = torch.cat((self.decoder_hidden, decoder_output),
                                   dim=1)
-        gate_prediction = self.stopnet(stopnet_input)
-        return decoder_output, gate_prediction, self.attention_layer.attention_weights
+        if self.separate_stopnet:
+            stop_token = self.stopnet(stopnet_input.detach())
+        else:
+            stop_token = self.stopnet(stopnet_input)
+        return decoder_output, stop_token, self.attention_layer.attention_weights
 
     def forward(self, inputs, memories, mask):
         memory = self.get_go_frame(inputs).unsqueeze(0)
@@ -451,8 +466,7 @@ class Decoder(nn.Module):
         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
             memory = memories[len(outputs)]
-            mel_output, stop_token, attention_weights = self.decode(
-                memory)
+            mel_output, stop_token, attention_weights = self.decode(memory)
             outputs += [mel_output.squeeze(1)]
             stop_tokens += [stop_token.squeeze(1)]
             alignments += [attention_weights]
@@ -481,7 +495,8 @@ class Decoder(nn.Module):
             alignments += [alignment]
 
             stop_flags[0] = stop_flags[0] or stop_token > 0.5
-            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8 and t > inputs.shape[1])
+            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8
+                                              and t > inputs.shape[1])
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 stop_count += 1
@@ -523,7 +538,8 @@ class Decoder(nn.Module):
             alignments += [alignment]
 
             stop_flags[0] = stop_flags[0] or stop_token > 0.5
-            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5 and t > inputs.shape[1])
+            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5
+                                              and t > inputs.shape[1])
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 stop_count += 1
@@ -541,7 +557,6 @@ class Decoder(nn.Module):
 
         return outputs, stop_tokens, alignments
 
-
     def inference_step(self, inputs, t, memory=None):
         """
         For debug purposes
diff --git a/models/tacotron.py b/models/tacotron.py
index adf74ab3..4243a039 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -15,7 +15,8 @@ class Tacotron(nn.Module):
                  padding_idx=None,
                  memory_size=5,
                  attn_win=False,
-                 attn_norm="sigmoid"):
+                 attn_norm="sigmoid",
+                 separate_stopnet=True):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
@@ -23,7 +24,8 @@ class Tacotron(nn.Module):
         self.embedding = nn.Embedding(num_chars, 256, padding_idx=padding_idx)
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(256)
-        self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win, attn_norm)
+        self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win,
+                               attn_norm, separate_stopnet)
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Sequential(
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
diff --git a/models/tacotron2.py b/models/tacotron2.py
index e9ce1a1b..da4894fc 100644
--- a/models/tacotron2.py
+++ b/models/tacotron2.py
@@ -9,7 +9,17 @@ from utils.generic_utils import sequence_mask
 
 # TODO: match function arguments with tacotron
 class Tacotron2(nn.Module):
-    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", prenet_dropout=True, forward_attn=False, trans_agent=False, location_attn=True):
+    def __init__(self,
+                 num_chars,
+                 r,
+                 attn_win=False,
+                 attn_norm="softmax",
+                 prenet_type="original",
+                 prenet_dropout=True,
+                 forward_attn=False,
+                 trans_agent=False,
+                 location_attn=True,
+                 separate_stopnet=True):
         super(Tacotron2, self).__init__()
         self.n_mel_channels = 80
         self.n_frames_per_step = r
@@ -18,7 +28,10 @@ class Tacotron2(nn.Module):
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(512)
-        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, location_attn)
+        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win,
+                               attn_norm, prenet_type, prenet_dropout,
+                               forward_attn, trans_agent, location_attn,
+                               separate_stopnet)
         self.postnet = Postnet(self.n_mel_channels)
 
     def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
@@ -50,14 +63,14 @@ class Tacotron2(nn.Module):
             mel_outputs, mel_outputs_postnet, alignments)
         return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
 
-
     def inference_truncated(self, text):
         """
         Preserve model states for continuous inference
         """
         embedded_inputs = self.embedding(text).transpose(1, 2)
         encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
-        mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated(encoder_outputs)
+        mel_outputs, stop_tokens, alignments = self.decoder.inference_truncated(
+            encoder_outputs)
         mel_outputs_postnet = self.postnet(mel_outputs)
         mel_outputs_postnet = mel_outputs + mel_outputs_postnet
         mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
diff --git a/tests/test_text_processing.py b/tests/test_text_processing.py
index d991cee9..6dedd943 100644
--- a/tests/test_text_processing.py
+++ b/tests/test_text_processing.py
@@ -57,6 +57,15 @@ def test_phoneme_to_sequence():
     print(len(sequence))
     assert text_hat == gt
 
+    # padding char
+    text = "_Be a _voice, not an! echo_"
+    sequence = phoneme_to_sequence(text, text_cleaner, lang)
+    text_hat = sequence_to_phoneme(sequence)
+    gt = "biː ɐ vɔɪs, nɑːt ɐn! ɛkoʊ"
+    print(text_hat)
+    print(len(sequence))
+    assert text_hat == gt
+
 
 def test_text2phone():
     text = "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase, the grey matter in the parts of the brain responsible for emotional regulation and learning!"
diff --git a/train.py b/train.py
index f904b2fd..b844024a 100644
--- a/train.py
+++ b/train.py
@@ -141,11 +141,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if not c.separate_stopnet and c.stopnet:
             loss += stop_loss
 
-        # backpass and check the grad norm for spec losses
-        if c.separate_stopnet:
-            loss.backward(retain_graph=True)
-        else:
-            loss.backward()
+        loss.backward()
         optimizer, current_lr = weight_decay(optimizer, c.wd)
         grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index 760cbae3..1158ea06 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -253,7 +253,8 @@ def setup_model(num_chars, c):
             r=c.r,
             attn_win=c.windowing,
             attn_norm=c.attention_norm,
-            memory_size=c.memory_size)
+            memory_size=c.memory_size,
+            separate_stopnet=c.separate_stopnet)
     elif c.model.lower() == "tacotron2":
         model = MyModel(
             num_chars=num_chars,
@@ -264,5 +265,6 @@ def setup_model(num_chars, c):
             prenet_dropout=c.prenet_dropout,
             forward_attn=c.use_forward_attn,
             trans_agent=c.transition_agent,
-            location_attn=c.location_attn)
+            location_attn=c.location_attn,
+            separate_stopnet=c.separate_stopnet)
     return model
\ No newline at end of file
diff --git a/utils/text/__init__.py b/utils/text/__init__.py
index a05f100a..24433cf3 100644
--- a/utils/text/__init__.py
+++ b/utils/text/__init__.py
@@ -136,8 +136,8 @@ def _arpabet_to_sequence(text):
 
 
 def _should_keep_symbol(s):
-    return s in _symbol_to_id and s is not '_' and s is not '~'
+    return s in _symbol_to_id and s not in ['~', '^', '_']
 
 
 def _should_keep_phoneme(p):
-    return p in _phonemes_to_id and p is not '_' and p is not '~'
+    return p in _phonemes_to_id and p not in ['~', '^', '_']
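Note (not part of the patch): with separate_stopnet enabled, both decoders now feed the stopnet a detached copy of its input, so the stopnet's gradients stop at its own weights and never flow back into the decoder. The stop-token branch then lives on its own small graph, which is why the loss.backward(retain_graph=True) special case in train.py can be dropped. Below is a minimal PyTorch sketch of that mechanism, using stand-in nn.Linear modules rather than the repo's actual classes:

    import torch
    import torch.nn as nn

    decoder = nn.Linear(4, 4)   # stand-in for the decoder stack
    stopnet = nn.Linear(4, 1)   # stand-in for self.stopnet

    decoder_out = decoder(torch.randn(2, 4))
    # detach() cuts the graph between the stopnet branch and the decoder
    stop_logits = stopnet(decoder_out.detach())

    spec_loss = decoder_out.pow(2).mean()
    stop_loss = stop_logits.pow(2).mean()

    spec_loss.backward()  # frees the decoder graph; no retain_graph=True needed
    stop_loss.backward()  # still valid: its graph only spans the stopnet branch

    assert decoder.weight.grad is not None  # driven by spec_loss only
    assert stopnet.weight.grad is not None  # driven by stop_loss only

The same ordering can hold in training: the spectrogram losses free their graph with a single backward call, and a separate stopnet optimizer can still back-propagate the stop loss afterwards.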