From d282222553c97821f1028495fb7161d2f60b491d Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 28 Apr 2020 18:16:02 +0200 Subject: [PATCH 01/16] renaming layers to be converted to TF counterpart --- layers/common_layers.py | 14 ++-- layers/tacotron2.py | 148 +++++++++++++++++++--------------------- models/tacotron2.py | 2 +- 3 files changed, 80 insertions(+), 84 deletions(-) diff --git a/layers/common_layers.py b/layers/common_layers.py index 8b7ed125..d2afe012 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -33,7 +33,7 @@ class LinearBN(nn.Module): super(LinearBN, self).__init__() self.linear_layer = torch.nn.Linear( in_features, out_features, bias=bias) - self.bn = nn.BatchNorm1d(out_features) + self.batch_normalization = nn.BatchNorm1d(out_features) self._init_w(init_gain) def _init_w(self, init_gain): @@ -45,7 +45,7 @@ class LinearBN(nn.Module): out = self.linear_layer(x) if len(out.shape) == 3: out = out.permute(1, 2, 0) - out = self.bn(out) + out = self.batch_normalization(out) if len(out.shape) == 3: out = out.permute(2, 0, 1) return out @@ -63,18 +63,18 @@ class Prenet(nn.Module): self.prenet_dropout = prenet_dropout in_features = [in_features] + out_features[:-1] if prenet_type == "bn": - self.layers = nn.ModuleList([ + self.linear_layers = nn.ModuleList([ LinearBN(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features) ]) elif prenet_type == "original": - self.layers = nn.ModuleList([ + self.linear_layers = nn.ModuleList([ Linear(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features) ]) def forward(self, x): - for linear in self.layers: + for linear in self.linear_layers: if self.prenet_dropout: x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training) else: @@ -93,7 +93,7 @@ class LocationLayer(nn.Module): attention_n_filters=32, attention_kernel_size=31): super(LocationLayer, self).__init__() - self.location_conv = nn.Conv1d( + self.location_conv1d = nn.Conv1d( in_channels=2, out_channels=attention_n_filters, kernel_size=attention_kernel_size, @@ -104,7 +104,7 @@ class LocationLayer(nn.Module): attention_n_filters, attention_dim, bias=False, init_gain='tanh') def forward(self, attention_cat): - processed_attention = self.location_conv(attention_cat) + processed_attention = self.location_conv1d(attention_cat) processed_attention = self.location_dense( processed_attention.transpose(1, 2)) return processed_attention diff --git a/layers/tacotron2.py b/layers/tacotron2.py index fa76a6b2..3e439b9b 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -6,130 +6,126 @@ from .common_layers import init_attn, Prenet, Linear class ConvBNBlock(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, nonlinear=None): + def __init__(self, in_channels, out_channels, kernel_size, activation=None): super(ConvBNBlock, self).__init__() assert (kernel_size - 1) % 2 == 0 padding = (kernel_size - 1) // 2 - conv1d = nn.Conv1d(in_channels, + self.convolution1d = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding) - norm = nn.BatchNorm1d(out_channels) - dropout = nn.Dropout(p=0.5) - if nonlinear == 'relu': - self.net = nn.Sequential(conv1d, norm, nn.ReLU(), dropout) - elif nonlinear == 'tanh': - self.net = nn.Sequential(conv1d, norm, nn.Tanh(), dropout) + self.batch_normalization = nn.BatchNorm1d(out_channels) + self.dropout = nn.Dropout(p=0.5) + if activation == 'relu': + self.activation = nn.ReLU() + elif activation == 'tanh': + self.activation = nn.Tanh() else: - self.net = nn.Sequential(conv1d, norm, dropout) + self.activation = nn.Identity() def forward(self, x): - output = self.net(x) - return output + o = self.convolution1d(x) + o = self.batch_normalization(o) + o = self.activation(o) + o = self.dropout(o) + return o class Postnet(nn.Module): - def __init__(self, mel_dim, num_convs=5): + def __init__(self, output_dim, num_convs=5): super(Postnet, self).__init__() self.convolutions = nn.ModuleList() self.convolutions.append( - ConvBNBlock(mel_dim, 512, kernel_size=5, nonlinear='tanh')) + ConvBNBlock(output_dim, 512, kernel_size=5, activation='tanh')) for _ in range(1, num_convs - 1): self.convolutions.append( - ConvBNBlock(512, 512, kernel_size=5, nonlinear='tanh')) + ConvBNBlock(512, 512, kernel_size=5, activation='tanh')) self.convolutions.append( - ConvBNBlock(512, mel_dim, kernel_size=5, nonlinear=None)) + ConvBNBlock(512, output_dim, kernel_size=5, activation=None)) def forward(self, x): + o = x for layer in self.convolutions: - x = layer(x) - return x + o = layer(o) + return o class Encoder(nn.Module): - def __init__(self, in_features=512): + def __init__(self, output_input_dim=512): super(Encoder, self).__init__() - convolutions = [] + self.convolutions = nn.ModuleList() for _ in range(3): - convolutions.append( - ConvBNBlock(in_features, in_features, 5, 'relu')) - self.convolutions = nn.Sequential(*convolutions) - self.lstm = nn.LSTM(in_features, - int(in_features / 2), + self.convolutions.append( + ConvBNBlock(output_input_dim, output_input_dim, 5, 'relu')) + self.lstm = nn.LSTM(output_input_dim, + int(output_input_dim / 2), num_layers=1, batch_first=True, bidirectional=True) self.rnn_state = None def forward(self, x, input_lengths): - x = self.convolutions(x) - x = x.transpose(1, 2) - x = nn.utils.rnn.pack_padded_sequence(x, + o = x + for layer in self.convolutions: + o = layer(o) + o = o.transpose(1, 2) + o = nn.utils.rnn.pack_padded_sequence(o, input_lengths, batch_first=True) self.lstm.flatten_parameters() - outputs, _ = self.lstm(x) - outputs, _ = nn.utils.rnn.pad_packed_sequence( - outputs, - batch_first=True, - ) - return outputs + o, _ = self.lstm(o) + o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True) + return o def inference(self, x): - x = self.convolutions(x) - x = x.transpose(1, 2) + o = x + for layer in self.convolutions: + o = layer(o) + o = x.transpose(1, 2) self.lstm.flatten_parameters() - outputs, _ = self.lstm(x) - return outputs - - def inference_truncated(self, x): - """ - Preserve encoder state for continuous inference - """ - x = self.convolutions(x) - x = x.transpose(1, 2) - self.lstm.flatten_parameters() - outputs, self.rnn_state = self.lstm(x, self.rnn_state) - return outputs + o, _ = self.lstm(o) + return o # adapted from https://github.com/NVIDIA/tacotron2/ class Decoder(nn.Module): # Pylint gets confused by PyTorch conventions here #pylint: disable=attribute-defined-outside-init - def __init__(self, in_features, memory_dim, r, attn_type, attn_win, attn_norm, + def __init__(self, input_dim, frame_dim, r, attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, location_attn, attn_K, separate_stopnet, speaker_embedding_dim): super(Decoder, self).__init__() - self.memory_dim = memory_dim + self.frame_dim = frame_dim self.r_init = r self.r = r - self.encoder_embedding_dim = in_features + self.encoder_embedding_dim = input_dim self.separate_stopnet = separate_stopnet + self.max_decoder_steps = 1000 + self.gate_threshold = 0.5 + + # model dimensions self.query_dim = 1024 self.decoder_rnn_dim = 1024 self.prenet_dim = 256 - self.max_decoder_steps = 1000 - self.gate_threshold = 0.5 + self.attn_dim = 128 self.p_attention_dropout = 0.1 self.p_decoder_dropout = 0.1 # memory -> |Prenet| -> processed_memory - prenet_dim = self.memory_dim - self.prenet = Prenet( - prenet_dim, - prenet_type, - prenet_dropout, - out_features=[self.prenet_dim, self.prenet_dim], - bias=False) + prenet_dim = self.frame_dim + self.prenet = Prenet(prenet_dim, + prenet_type, + prenet_dropout, + out_features=[self.prenet_dim, self.prenet_dim], + bias=False) - self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features, + self.attention_rnn = nn.LSTMCell(self.prenet_dim + input_dim, self.query_dim) self.attention = init_attn(attn_type=attn_type, query_dim=self.query_dim, - embedding_dim=in_features, + embedding_dim=input_dim, attention_dim=128, location_attention=location_attn, attention_location_n_filters=32, @@ -141,15 +137,15 @@ class Decoder(nn.Module): forward_attn_mask=forward_attn_mask, attn_K=attn_K) - self.decoder_rnn = nn.LSTMCell(self.query_dim + in_features, + self.decoder_rnn = nn.LSTMCell(self.query_dim + input_dim, self.decoder_rnn_dim, 1) - self.linear_projection = Linear(self.decoder_rnn_dim + in_features, - self.memory_dim * self.r_init) + self.linear_projection = Linear(self.decoder_rnn_dim + input_dim, + self.frame_dim * self.r_init) self.stopnet = nn.Sequential( nn.Dropout(0.1), - Linear(self.decoder_rnn_dim + self.memory_dim * self.r_init, + Linear(self.decoder_rnn_dim + self.frame_dim * self.r_init, 1, bias=True, init_gain='sigmoid')) @@ -161,7 +157,7 @@ class Decoder(nn.Module): def get_go_frame(self, inputs): B = inputs.size(0) memory = torch.zeros(1, device=inputs.device).repeat(B, - self.memory_dim * self.r) + self.frame_dim * self.r) return memory def _init_states(self, inputs, mask, keep_states=False): @@ -187,9 +183,9 @@ class Decoder(nn.Module): Reshape the spectrograms for given 'r' """ # Grouping multiple frames if necessary - if memory.size(-1) == self.memory_dim: + if memory.size(-1) == self.frame_dim: memory = memory.view(memory.shape[0], memory.size(1) // self.r, -1) - # Time first (T_decoder, B, memory_dim) + # Time first (T_decoder, B, frame_dim) memory = memory.transpose(0, 1) return memory @@ -197,22 +193,22 @@ class Decoder(nn.Module): alignments = torch.stack(alignments).transpose(0, 1) stop_tokens = torch.stack(stop_tokens).transpose(0, 1) outputs = torch.stack(outputs).transpose(0, 1).contiguous() - outputs = outputs.view(outputs.size(0), -1, self.memory_dim) + outputs = outputs.view(outputs.size(0), -1, self.frame_dim) outputs = outputs.transpose(1, 2) return outputs, stop_tokens, alignments def _update_memory(self, memory): if len(memory.shape) == 2: - return memory[:, self.memory_dim * (self.r - 1):] - return memory[:, :, self.memory_dim * (self.r - 1):] + return memory[:, self.frame_dim * (self.r - 1):] + return memory[:, :, self.frame_dim * (self.r - 1):] def decode(self, memory): ''' shapes: - - memory: B x r * self.memory_dim + - memory: B x r * self.frame_dim ''' # self.context: B x D_en - # query_input: B x D_en + (r * self.memory_dim) + # query_input: B x D_en + (r * self.frame_dim) query_input = torch.cat((memory, self.context), -1) # self.query and self.attention_rnn_cell_state : B x D_attn_rnn self.query, self.attention_rnn_cell_state = self.attention_rnn( @@ -235,16 +231,16 @@ class Decoder(nn.Module): # B x (D_decoder_rnn + D_en) decoder_hidden_context = torch.cat((self.decoder_hidden, self.context), dim=1) - # B x (self.r * self.memory_dim) + # B x (self.r * self.frame_dim) decoder_output = self.linear_projection(decoder_hidden_context) - # B x (D_decoder_rnn + (self.r * self.memory_dim)) + # B x (D_decoder_rnn + (self.r * self.frame_dim)) stopnet_input = torch.cat((self.decoder_hidden, decoder_output), dim=1) if self.separate_stopnet: stop_token = self.stopnet(stopnet_input.detach()) else: stop_token = self.stopnet(stopnet_input) # select outputs for the reduction rate self.r - decoder_output = decoder_output[:, :self.r * self.memory_dim] + decoder_output = decoder_output[:, :self.r * self.frame_dim] return decoder_output, self.attention.attention_weights, stop_token def forward(self, inputs, memories, mask, speaker_embeddings=None): diff --git a/models/tacotron2.py b/models/tacotron2.py index d530774a..3e7adfca 100644 --- a/models/tacotron2.py +++ b/models/tacotron2.py @@ -29,7 +29,7 @@ class Tacotron2(nn.Module): super(Tacotron2, self).__init__() self.postnet_output_dim = postnet_output_dim self.decoder_output_dim = decoder_output_dim - self.n_frames_per_step = r + self.r = r self.bidirectional_decoder = bidirectional_decoder decoder_dim = 512 if num_speakers > 1 else 512 encoder_dim = 512 if num_speakers > 1 else 512 From 736f169cc99f13e2fd7534df3a43d12147bc367b Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 28 Apr 2020 18:16:37 +0200 Subject: [PATCH 02/16] tf lstm does not match torch lstm wrt bias vectors. So I avoid bias in LSTM as an easy solution. --- layers/tacotron2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/layers/tacotron2.py b/layers/tacotron2.py index 3e439b9b..35a5c0bb 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -61,6 +61,7 @@ class Encoder(nn.Module): int(output_input_dim / 2), num_layers=1, batch_first=True, + bias=False, bidirectional=True) self.rnn_state = None @@ -121,7 +122,8 @@ class Decoder(nn.Module): bias=False) self.attention_rnn = nn.LSTMCell(self.prenet_dim + input_dim, - self.query_dim) + self.query_dim, + bias=False) self.attention = init_attn(attn_type=attn_type, query_dim=self.query_dim, From de2918c85b5afb2648d2f39a0a47fcee204ba101 Mon Sep 17 00:00:00 2001 From: erogol Date: Fri, 1 May 2020 14:34:14 +0200 Subject: [PATCH 03/16] bug fixes --- layers/tacotron2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/layers/tacotron2.py b/layers/tacotron2.py index 35a5c0bb..10c03570 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -82,8 +82,8 @@ class Encoder(nn.Module): o = x for layer in self.convolutions: o = layer(o) - o = x.transpose(1, 2) - self.lstm.flatten_parameters() + o = o.transpose(1, 2) + # self.lstm.flatten_parameters() o, _ = self.lstm(o) return o @@ -140,7 +140,8 @@ class Decoder(nn.Module): attn_K=attn_K) self.decoder_rnn = nn.LSTMCell(self.query_dim + input_dim, - self.decoder_rnn_dim, 1) + self.decoder_rnn_dim, + bias=False) self.linear_projection = Linear(self.decoder_rnn_dim + input_dim, self.frame_dim * self.r_init) From 9504b71f79cd58cad456654aa4d28740662ff3d8 Mon Sep 17 00:00:00 2001 From: erogol Date: Fri, 1 May 2020 23:06:51 +0200 Subject: [PATCH 04/16] fix lstm biases True --- layers/tacotron2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/layers/tacotron2.py b/layers/tacotron2.py index 10c03570..4454c89e 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -123,7 +123,7 @@ class Decoder(nn.Module): self.attention_rnn = nn.LSTMCell(self.prenet_dim + input_dim, self.query_dim, - bias=False) + bias=True) self.attention = init_attn(attn_type=attn_type, query_dim=self.query_dim, @@ -141,7 +141,7 @@ class Decoder(nn.Module): self.decoder_rnn = nn.LSTMCell(self.query_dim + input_dim, self.decoder_rnn_dim, - bias=False) + bias=True) self.linear_projection = Linear(self.decoder_rnn_dim + input_dim, self.frame_dim * self.r_init) From 6f5c8773d6486e86418496913f53b0c8ec82d087 Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 4 May 2020 21:03:03 +0200 Subject: [PATCH 05/16] enable encoder lstm bias --- layers/tacotron2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/layers/tacotron2.py b/layers/tacotron2.py index 4454c89e..b9aec6fe 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -61,7 +61,7 @@ class Encoder(nn.Module): int(output_input_dim / 2), num_layers=1, batch_first=True, - bias=False, + bias=True, bidirectional=True) self.rnn_state = None From d99fda8e42c0f131ed138a1fb29e188819977093 Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 5 May 2020 17:36:12 +0200 Subject: [PATCH 06/16] init batch norm explicit initial values --- layers/common_layers.py | 2 +- layers/tacotron2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/layers/common_layers.py b/layers/common_layers.py index d2afe012..24433269 100644 --- a/layers/common_layers.py +++ b/layers/common_layers.py @@ -33,7 +33,7 @@ class LinearBN(nn.Module): super(LinearBN, self).__init__() self.linear_layer = torch.nn.Linear( in_features, out_features, bias=bias) - self.batch_normalization = nn.BatchNorm1d(out_features) + self.batch_normalization = nn.BatchNorm1d(out_features, momentum=0.1, eps=1e-5) self._init_w(init_gain) def _init_w(self, init_gain): diff --git a/layers/tacotron2.py b/layers/tacotron2.py index b9aec6fe..bdb169be 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -14,7 +14,7 @@ class ConvBNBlock(nn.Module): out_channels, kernel_size, padding=padding) - self.batch_normalization = nn.BatchNorm1d(out_channels) + self.batch_normalization = nn.BatchNorm1d(out_channels, momentum=0.1, eps=1e-5) self.dropout = nn.Dropout(p=0.5) if activation == 'relu': self.activation = nn.ReLU() From b3ec50b5c4f1bebfea642af0298ce62a5c3bc518 Mon Sep 17 00:00:00 2001 From: erogol Date: Wed, 6 May 2020 16:37:30 +0200 Subject: [PATCH 07/16] tf bacend for synthesis --- utils/synthesis.py | 91 +++++++++++++++++++++++++++++++++++----------- utils/visual.py | 1 - 2 files changed, 70 insertions(+), 22 deletions(-) diff --git a/utils/synthesis.py b/utils/synthesis.py index 9158ef02..0c68dbf2 100644 --- a/utils/synthesis.py +++ b/utils/synthesis.py @@ -1,3 +1,7 @@ +import pkg_resources +installed = {pkg.key for pkg in pkg_resources.working_set} +if 'tensorflow' in installed: + import tensorflow as tf import torch import numpy as np from .text import text_to_sequence, phoneme_to_sequence @@ -14,23 +18,32 @@ def text_to_seqvec(text, CONFIG, use_cuda): dtype=np.int32) else: seq = np.asarray(text_to_sequence(text, text_cleaner, tp=CONFIG.characters if 'characters' in CONFIG.keys() else None), dtype=np.int32) - # torch tensor - chars_var = torch.from_numpy(seq).unsqueeze(0) - if use_cuda: - chars_var = chars_var.cuda() - return chars_var.long() + return seq + + +def numpy_to_torch(np_array, dtype, cuda=False): + if np_array is None: + return None + tensor = torch.Tensor(np_array, dtype=dtype) + if cuda: + return tensor.cuda() + return tensor + + +def numpy_to_tf(np_array, dtype): + if np_array is None: + return None + tensor = tf.convert_to_tensor(np_array, dtype=dtype) + return tensor def compute_style_mel(style_wav, ap, use_cuda): - print(style_wav) - style_mel = torch.FloatTensor(ap.melspectrogram( - ap.load_wav(style_wav))).unsqueeze(0) - if use_cuda: - return style_mel.cuda() + style_mel = ap.melspectrogram( + ap.load_wav(style_wav)).expand_dims(0) return style_mel -def run_model(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None): +def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None): if CONFIG.use_gst: decoder_output, postnet_output, alignments, stop_tokens = model.inference( inputs, style_mel=style_mel, speaker_ids=speaker_id) @@ -44,11 +57,31 @@ def run_model(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None) return decoder_output, postnet_output, alignments, stop_tokens -def parse_outputs(postnet_output, decoder_output, alignments): +def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None): + if CONFIG.use_gst: + raise NotImplemented(' [!] GST inference not implemented for TF') + if truncated: + raise NotImplemented(' [!] Truncated inference not implemented for TF') + # TODO: handle multispeaker case + decoder_output, postnet_output, alignments, stop_tokens = model( + inputs, training=False) + return decoder_output, postnet_output, alignments, stop_tokens + + +def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens): postnet_output = postnet_output[0].data.cpu().numpy() decoder_output = decoder_output[0].data.cpu().numpy() alignment = alignments[0].cpu().data.numpy() - return postnet_output, decoder_output, alignment + stop_tokens = stop_tokens[0].cpu().numpy() + return postnet_output, decoder_output, alignment, stop_tokens + + +def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens): + postnet_output = postnet_output[0].numpy() + decoder_output = decoder_output[0].numpy() + alignment = alignments[0].numpy() + stop_tokens = stop_tokens[0].numpy() + return postnet_output, decoder_output, alignment, stop_tokens def trim_silence(wav, ap): @@ -98,7 +131,8 @@ def synthesis(model, truncated=False, enable_eos_bos_chars=False, #pylint: disable=unused-argument use_griffin_lim=False, - do_trim_silence=False): + do_trim_silence=False, + backend='torch'): """Synthesize voice for the given text. Args: @@ -114,6 +148,7 @@ def synthesis(model, for continuous inference at long texts. enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence. do_trim_silence (bool): trim silence after synthesis. + backend (str): tf or torch """ # GST processing style_mel = None @@ -121,15 +156,29 @@ def synthesis(model, style_mel = compute_style_mel(style_wav, ap, use_cuda) # preprocess the given text inputs = text_to_seqvec(text, CONFIG, use_cuda) - speaker_id = id_to_torch(speaker_id) - if speaker_id is not None and use_cuda: - speaker_id = speaker_id.cuda() + # pass tensors to backend + if backend == 'torch': + speaker_id = id_to_torch(speaker_id) + style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) + inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda) + inputs = inputs.unsqueeze(0) + else: + # TODO: handle speaker id for tf model + style_mel = numpy_to_tf(style_mel, tf.float32) + inputs = numpy_to_tf(inputs, tf.int32) + inputs = tf.expand_dims(inputs, 0) # synthesize voice - decoder_output, postnet_output, alignments, stop_tokens = run_model( - model, inputs, CONFIG, truncated, speaker_id, style_mel) + if backend == 'torch': + decoder_output, postnet_output, alignments, stop_tokens = run_model_torch( + model, inputs, CONFIG, truncated, speaker_id, style_mel) + postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch( + postnet_output, decoder_output, alignments, stop_tokens) + else: + decoder_output, postnet_output, alignments, stop_tokens = run_model_tf( + model, inputs, CONFIG, truncated, speaker_id, style_mel) + postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf( + postnet_output, decoder_output, alignments, stop_tokens) # convert outputs to numpy - postnet_output, decoder_output, alignment = parse_outputs( - postnet_output, decoder_output, alignments) # plot results wav = None if use_griffin_lim: diff --git a/utils/visual.py b/utils/visual.py index 8789cf8f..87fbc8e4 100644 --- a/utils/visual.py +++ b/utils/visual.py @@ -61,7 +61,6 @@ def visualize(alignment, postnet_output, stop_tokens, text, hop_length, CONFIG, plt.yticks(range(len(text)), list(text)) plt.colorbar() # plot stopnet predictions - stop_tokens = stop_tokens.squeeze().detach().to('cpu').numpy() plt.subplot(num_plot, 1, 2) plt.plot(range(len(stop_tokens)), list(stop_tokens)) # plot postnet spectrogram From 84c5c4a5871ffe3f87e01feee429a9ffec154c0e Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 12 May 2020 13:46:16 +0200 Subject: [PATCH 08/16] config remove empty chars --- config.json | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/config.json b/config.json index da3fe286..c23bd004 100644 --- a/config.json +++ b/config.json @@ -1,5 +1,5 @@ { - "model": "Tacotron2", + "model": "Tacotron2", "run_name": "ljspeech", "run_description": "tacotron2", @@ -11,12 +11,12 @@ "hop_length": 256, // stft window hop-lengh in ms. "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. - + // Audio processing parameters "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. "preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. - + // Silence trimming "do_trim_silence": true,// enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) "trim_db": 60, // threshold for timming silence. Set this according to your dataset. @@ -26,7 +26,7 @@ "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. // MelSpectrogram parameters - "num_mels": 80, // size of the mel spec frame. + "num_mels": 80, // size of the mel spec frame. "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! @@ -50,7 +50,7 @@ // "punctuations":"!'(),-.:;? ", // "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ" // }, - + // DISTRIBUTED TRAINING "distributed":{ "backend": "nccl", @@ -61,8 +61,8 @@ // TRAINING "batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. - "eval_batch_size":16, - "r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled. + "eval_batch_size":16, + "r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled. "gradual_training": [[0, 7, 64], [1, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]], //set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled. For Tacotron, you might need to reduce the 'batch_size' as you proceeed. "loss_masking": true, // enable / disable loss masking against the sequence padding. "ga_alpha": 10.0, // weight for guided attention loss. If > 0, guided attention is enabled. @@ -80,11 +80,11 @@ "wd": 0.000001, // Weight decay weight. "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" "seq_len_norm": false, // Normalize eash sample loss with its length to alleviate imbalanced datasets. Use it if your dataset is small or has skewed distribution of sequence lengths. - + // TACOTRON PRENET - "memory_size": -1, // ONLY TACOTRON - size of the memory queue used fro storing last decoder predictions for auto-regression. If < 0, memory queue is disabled and decoder only uses the last prediction frame. + "memory_size": -1, // ONLY TACOTRON - size of the memory queue used fro storing last decoder predictions for auto-regression. If < 0, memory queue is disabled and decoder only uses the last prediction frame. "prenet_type": "original", // "original" or "bn". - "prenet_dropout": true, // enable/disable dropout at prenet. + "prenet_dropout": true, // enable/disable dropout at prenet. // ATTENTION "attention_type": "original", // 'original' or 'graves' @@ -98,16 +98,16 @@ "bidirectional_decoder": false, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset. // STOPNET - "stopnet": true, // Train stopnet predicting the end of synthesis. + "stopnet": true, // Train stopnet predicting the end of synthesis. "separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER. // TENSORBOARD and LOGGING "print_step": 25, // Number of steps to log traning on console. - "print_eval": false, // If True, it prints loss values in evalulation. + "print_eval": false, // If True, it prints loss values in evalulation. "save_step": 10000, // Number of training steps expected to save traninpg stats and checkpoints. "checkpoint": true, // If true, it saves checkpoints per "save_step" - "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. - + "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. + // DATA LOADING "text_cleaner": "phoneme_cleaners", "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars. @@ -119,7 +119,7 @@ // PATHS "output_path": "/home/erogol/Models/LJSpeech/", - + // PHONEMES "phoneme_cache_path": "mozilla_us_phonemes_3", // phoneme computation is slow, therefore, it caches results in the given folder. "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation. From 68dbcee746a775df42250b1940b54e88ae670e62 Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 12 May 2020 13:49:49 +0200 Subject: [PATCH 09/16] import condition update for synthesis with TF --- utils/synthesis.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/synthesis.py b/utils/synthesis.py index 0c68dbf2..ae3a7df7 100644 --- a/utils/synthesis.py +++ b/utils/synthesis.py @@ -1,6 +1,6 @@ import pkg_resources installed = {pkg.key for pkg in pkg_resources.working_set} -if 'tensorflow' in installed: +if 'tensorflow' in installed or 'tensorflow-gpu' in installed: import tensorflow as tf import torch import numpy as np @@ -24,9 +24,9 @@ def text_to_seqvec(text, CONFIG, use_cuda): def numpy_to_torch(np_array, dtype, cuda=False): if np_array is None: return None - tensor = torch.Tensor(np_array, dtype=dtype) + tensor = torch.Tensor(np_array, dtype=dtype) if cuda: - return tensor.cuda() + return tensor.cuda() return tensor From 1cd25ccf0d3842fc954263f5d538827cfb82e040 Mon Sep 17 00:00:00 2001 From: erogol Date: Tue, 12 May 2020 16:16:55 +0200 Subject: [PATCH 10/16] bug fix --- utils/synthesis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/synthesis.py b/utils/synthesis.py index ae3a7df7..188a3acf 100644 --- a/utils/synthesis.py +++ b/utils/synthesis.py @@ -24,7 +24,7 @@ def text_to_seqvec(text, CONFIG, use_cuda): def numpy_to_torch(np_array, dtype, cuda=False): if np_array is None: return None - tensor = torch.Tensor(np_array, dtype=dtype) + tensor = torch.as_tensor(np_array, dtype=dtype) if cuda: return tensor.cuda() return tensor From d5d9e6e8ea87995a5679d78b24e7df2a3c88e185 Mon Sep 17 00:00:00 2001 From: erogol Date: Wed, 13 May 2020 13:52:17 +0200 Subject: [PATCH 11/16] bug fix --- utils/synthesis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/synthesis.py b/utils/synthesis.py index 188a3acf..3903ba44 100644 --- a/utils/synthesis.py +++ b/utils/synthesis.py @@ -64,7 +64,7 @@ def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=No raise NotImplemented(' [!] Truncated inference not implemented for TF') # TODO: handle multispeaker case decoder_output, postnet_output, alignments, stop_tokens = model( - inputs, training=False) + inputs) return decoder_output, postnet_output, alignments, stop_tokens From 67397be1c096a37e0d4e97a729f8c1c144feca35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Mon, 18 May 2020 11:02:36 +0200 Subject: [PATCH 12/16] tf folder add --- tf/README.md | 4 + tf/convert_tacotron2_torch_to_tf.py | 196 +++++++ tf/layers/common_layers.py | 258 ++++++++++ tf/layers/tacotron2.py | 231 +++++++++ tf/models/tacotron2.py | 72 +++ tf/notebooks/Benchmark-TTS_tf.ipynb | 708 ++++++++++++++++++++++++++ tf/requirements | 2 + tf/utils/convert_torch_to_tf_utils.py | 83 +++ tf/utils/generic_utils.py | 105 ++++ tf/utils/tf_utils.py | 8 + 10 files changed, 1667 insertions(+) create mode 100644 tf/README.md create mode 100644 tf/convert_tacotron2_torch_to_tf.py create mode 100644 tf/layers/common_layers.py create mode 100644 tf/layers/tacotron2.py create mode 100644 tf/models/tacotron2.py create mode 100644 tf/notebooks/Benchmark-TTS_tf.ipynb create mode 100644 tf/requirements create mode 100644 tf/utils/convert_torch_to_tf_utils.py create mode 100644 tf/utils/generic_utils.py create mode 100644 tf/utils/tf_utils.py diff --git a/tf/README.md b/tf/README.md new file mode 100644 index 00000000..24e09a06 --- /dev/null +++ b/tf/README.md @@ -0,0 +1,4 @@ +## Utilities to Convert Models to Tensorflow2 +You can find some utilities to convert Torch models to Tensorflow with an experimental Tacotron2 implemenation in Tensorflow2 (>=2.2). However, our released Torch models may not work with this module due to additional changes layer naming convention. Therefore, you need to train new models to be converted to TF. + +This is an experimental release. If you encounter an error, please put an issue or in the best send a PR but you are mostly on your own. \ No newline at end of file diff --git a/tf/convert_tacotron2_torch_to_tf.py b/tf/convert_tacotron2_torch_to_tf.py new file mode 100644 index 00000000..512b0a4d --- /dev/null +++ b/tf/convert_tacotron2_torch_to_tf.py @@ -0,0 +1,196 @@ +# %% +import sys +sys.path.append('/home/erogol/Projects') +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '' +# %% +import argparse +import numpy as np +import torch +import tensorflow as tf +from fuzzywuzzy import fuzz + +from TTS.utils.text.symbols import make_symbols, phonemes, symbols +from TTS.utils.generic_utils import setup_model, count_parameters +from TTS.utils.io import load_config +from TTS_tf.models.tacotron2 import Tacotron2 +from TTS_tf.utils.convert_torch_to_tf_utils import compare_torch_tf, tf_create_dummy_inputs, transfer_weights_torch_to_tf, convert_tf_name +from TTS_tf.utils.generic_utils import save_checkpoint + + +parser = argparse.ArgumentParser() +parser.add_argument( + '--torch_model_path', + type=str, + help='Path to target torch model to be converted to TF.') +parser.add_argument( + '--config_path', + type=str, + help='Path to config file of torch model.') +parser.add_argument( + '--output_path', + type=str, + help='path to save TF model weights.') +args = parser.parse_args() + +# load model config +config_path = args.config_path +c = load_config(config_path) +num_speakers = 0 + +# init torch model +num_chars = len(phonemes) if c.use_phonemes else len(symbols) +model = setup_model(num_chars, num_speakers, c) +checkpoint = torch.load(args.torch_model_path, map_location=torch.device('cpu')) +state_dict = checkpoint['model'] +model.load_state_dict(state_dict) + +# init tf model +model_tf = Tacotron2(num_chars=num_chars, + num_speakers=num_speakers, + r=model.decoder.r, + postnet_output_dim=c.audio['num_mels'], + decoder_output_dim=c.audio['num_mels'], + attn_type=c.attention_type, + attn_win=c.windowing, + attn_norm=c.attention_norm, + prenet_type=c.prenet_type, + prenet_dropout=c.prenet_dropout, + forward_attn=c.use_forward_attn, + trans_agent=c.transition_agent, + forward_attn_mask=c.forward_attn_mask, + location_attn=c.location_attn, + attn_K=c.attention_heads, + separate_stopnet=c.separate_stopnet, + bidirectional_decoder=c.bidirectional_decoder) + +# set initial layer mapping - these are not captured by the below heuristic approach +# TODO: set layer names so that we can remove these manual matching +common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE' +var_map = [ + ('tacotron2/embedding/embeddings:0', 'embedding.weight'), + ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/kernel:0', 'encoder.lstm.weight_ih_l0'), + ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0', 'encoder.lstm.weight_hh_l0'), + ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/kernel:0', 'encoder.lstm.weight_ih_l0_reverse'), + ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0', 'encoder.lstm.weight_hh_l0_reverse'), + ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/bias:0', ('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')), + ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/bias:0', ('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')), + ('attention/v/kernel:0', 'decoder.attention.v.linear_layer.weight'), + ('decoder/linear_projection/kernel:0', 'decoder.linear_projection.linear_layer.weight'), + ('decoder/stopnet/kernel:0', 'decoder.stopnet.1.linear_layer.weight') +] + + +# %% +# get tf_model graph +input_ids, input_lengths, mel_outputs, mel_lengths = tf_create_dummy_inputs() +mel_pred = model_tf(input_ids, training=False) + +# get tf variables +tf_vars = model_tf.weights + +# match variable names with fuzzy logic +torch_var_names = list(state_dict.keys()) +tf_var_names = [we.name for we in model_tf.weights] +for tf_name in tf_var_names: + # skip re-mapped layer names + if tf_name in [name[0] for name in var_map]: + continue + tf_name_edited = convert_tf_name(tf_name) + ratios = [fuzz.ratio(torch_name, tf_name_edited) for torch_name in torch_var_names] + max_idx = np.argmax(ratios) + matching_name = torch_var_names[max_idx] + del torch_var_names[max_idx] + var_map.append((tf_name, matching_name)) + + +# %% +# print variable match +from pprint import pprint +pprint(var_map) +pprint(torch_var_names) + +# pass weights +tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict) + +# Compare TF and TORCH models +# %% +# check embedding outputs +model.eval() +input_ids = torch.randint(0, 24, (1, 128)).long() + +o_t = model.embedding(input_ids) +o_tf = model_tf.embedding(input_ids.detach().numpy()) +assert abs(o_t.detach().numpy() - o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() - o_tf.numpy()).sum() + +# compare encoder outputs +oo_en = model.encoder.inference(o_t.transpose(1,2)) +ooo_en = model_tf.encoder(o_t.detach().numpy(), training=False) +assert compare_torch_tf(oo_en, ooo_en) < 1e-5 + +# compare decoder.attention_rnn +inp = torch.rand([1, 768]) +inp_tf = inp.numpy() +model.decoder._init_states(oo_en, mask=None) +output, cell_state = model.decoder.attention_rnn(inp) +states = model_tf.decoder.build_decoder_initial_states(1,512,128) +output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf, states[2], training=False) +assert compare_torch_tf(output, output_tf).mean() < 1e-5 + +# compare decoder.attention +query = output +inputs = torch.rand([1, 128, 512]) +query_tf = query.detach().numpy() +inputs_tf = inputs.numpy() + +model.decoder.attention.init_states(inputs) +processes_inputs = model.decoder.attention.preprocess_inputs(inputs) +loc_attn, proc_query = model.decoder.attention.get_location_attention(query, processes_inputs) +context = model.decoder.attention(query, inputs, processes_inputs, None) + +model_tf.decoder.attention.process_values(tf.convert_to_tensor(inputs_tf)) +loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf) +context_tf = model_tf.decoder.attention(query_tf, training=False) + +assert compare_torch_tf(loc_attn, loc_attn_tf).mean() < 1e-5 +assert compare_torch_tf(proc_query, proc_query_tf).mean() < 1e-5 +assert compare_torch_tf(context, context_tf) < 1e-5 + +# compare decoder.decoder_rnn +input = torch.rand([1, 1536]) +input_tf = input.numpy() +model.decoder._init_states(oo_en, mask=None) +output, cell_state = model.decoder.decoder_rnn(input, [model.decoder.decoder_hidden, model.decoder.decoder_cell]) +states = model_tf.decoder.build_decoder_initial_states(1,512,128) +output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf, states[3], training=False) +assert abs(input - input_tf).mean() < 1e-5 +assert compare_torch_tf(output, output_tf).mean() < 1e-5 + +# compare decoder.linear_projection +input = torch.rand([1, 1536]) +input_tf = input.numpy() +output = model.decoder.linear_projection(input) +output_tf = model_tf.decoder.linear_projection(input_tf, training=False) +assert compare_torch_tf(output, output_tf) < 1e-5 + +# compare decoder outputs +model.decoder.max_decoder_steps = 100 +model_tf.decoder.set_max_decoder_steps(100) +output, align, stop = model.decoder.inference(oo_en) +states = model_tf.decoder.build_decoder_initial_states(1,512,128) +output_tf, align_tf, stop_tf = model_tf.decoder(ooo_en, states, training=False) +assert compare_torch_tf(output.transpose(1,2), output_tf) < 1e-4 + +# compare the whole model output +outputs_torch = model.inference(input_ids) +outputs_tf = model_tf(tf.convert_to_tensor(input_ids.numpy())) +print(abs(outputs_torch[0].numpy()[:, 0] - outputs_tf[0].numpy()[:, 0]).mean() ) +assert compare_torch_tf(outputs_torch[2][:, 50, :], outputs_tf[2][:, 50, :]) < 1e-5 +assert compare_torch_tf(outputs_torch[0], outputs_tf[0]) < 1e-4 + +# %% +# save tf model +save_checkpoint(model_tf, None, checkpoint['step'], checkpoint['epoch'], + checkpoint['r'], args.output_path) +print(' > Model conversion is successfully completed :).') + diff --git a/tf/layers/common_layers.py b/tf/layers/common_layers.py new file mode 100644 index 00000000..fba06e0b --- /dev/null +++ b/tf/layers/common_layers.py @@ -0,0 +1,258 @@ +import tensorflow as tf +from tensorflow import keras +from tensorflow.python.ops import math_ops +# from tensorflow_addons.seq2seq import BahdanauAttention + +from TTS.tf.utils.tf_utils import shape_list + + +class Linear(keras.layers.Layer): + def __init__(self, units, use_bias, **kwargs): + super(Linear, self).__init__(**kwargs) + self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer') + self.activation = keras.layers.ReLU() + + def call(self, x, training=None): + """ + shapes: + x: B x T x C + """ + return self.activation(self.linear_layer(x)) + + +class LinearBN(keras.layers.Layer): + def __init__(self, units, use_bias, **kwargs): + super(LinearBN, self).__init__(**kwargs) + self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer') + self.batch_normalization = keras.layers.BatchNormalization(axis=-1, momentum=0.90, epsilon=1e-5, name='batch_normalization') + self.activation = keras.layers.ReLU() + + def call(self, x, training=None): + """ + shapes: + x: B x T x C + """ + out = self.linear_layer(x) + out = self.batch_normalization(out, training=training) + return self.activation(out) + + +class Prenet(keras.layers.Layer): + def __init__(self, + prenet_type, + prenet_dropout, + units, + bias, + **kwargs): + super(Prenet, self).__init__(**kwargs) + self.prenet_type = prenet_type + self.prenet_dropout = prenet_dropout + self.linear_layers = [] + if prenet_type == "bn": + self.linear_layers += [LinearBN(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)] + elif prenet_type == "original": + self.linear_layers += [Linear(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)] + else: + raise RuntimeError(' [!] Unknown prenet type.') + if prenet_dropout: + self.dropout = keras.layers.Dropout(rate=0.5) + + def call(self, x, training=None): + """ + shapes: + x: B x T x C + """ + for linear in self.linear_layers: + if self.prenet_dropout: + x = self.dropout(linear(x), training=training) + else: + x = linear(x) + return x + + +def _sigmoid_norm(score): + attn_weights = tf.nn.sigmoid(score) + attn_weights = attn_weights / tf.reduce_sum(attn_weights, axis=1, keepdims=True) + return attn_weights + + +class Attention(keras.layers.Layer): + """TODO: implement forward_attention""" + """TODO: location sensitive attention""" + """TODO: implement attention windowing """ + def __init__(self, attn_dim, use_loc_attn, loc_attn_n_filters, + loc_attn_kernel_size, use_windowing, norm, use_forward_attn, + use_trans_agent, use_forward_attn_mask, **kwargs): + super(Attention, self).__init__(**kwargs) + self.use_loc_attn = use_loc_attn + self.loc_attn_n_filters = loc_attn_n_filters + self.loc_attn_kernel_size = loc_attn_kernel_size + self.use_windowing = use_windowing + self.norm = norm + self.use_forward_attn = use_forward_attn + self.use_trans_agent = use_trans_agent + self.use_forward_attn_mask = use_forward_attn_mask + self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name='query_layer/linear_layer') + self.inputs_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name=f'{self.name}/inputs_layer/linear_layer') + self.v = tf.keras.layers.Dense(1, use_bias=True, name='v/linear_layer') + if use_loc_attn: + self.location_conv1d = keras.layers.Conv1D( + filters=loc_attn_n_filters, + kernel_size=loc_attn_kernel_size, + padding='same', + use_bias=False, + name='location_layer/location_conv1d') + self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name='location_layer/location_dense') + if norm == 'softmax': + self.norm_func = tf.nn.softmax + elif norm == 'sigmoid': + self.norm_func = _sigmoid_norm + else: + raise ValueError("Unknown value for attention norm type") + + def init_states(self, batch_size, value_length): + states = () + if self.use_loc_attn: + attention_cum = tf.zeros([batch_size, value_length]) + attention_old = tf.zeros([batch_size, value_length]) + states = (attention_cum, attention_old) + return states + + def process_values(self, values): + """ cache values for decoder iterations """ + self.processed_values = self.inputs_layer(values) + self.values = values + + def get_loc_attn(self, query, states): + """ compute location attention, query layer and + unnorm. attention weights""" + attention_cum, attention_old = states + attn_cat = tf.stack([attention_old, attention_cum], + axis=2) + + processed_query = self.query_layer(tf.expand_dims(query, 1)) + processed_attn = self.location_dense(self.location_conv1d(attn_cat)) + score = self.v( + tf.nn.tanh(self.processed_values + processed_query + + processed_attn)) + score = tf.squeeze(score, axis=2) + return score, processed_query + + def get_attn(self, query): + """ compute query layer and unnormalized attention weights """ + processed_query = self.query_layer(tf.expand_dims(query, 1)) + score = self.v(tf.nn.tanh(self.processed_values + processed_query)) + score = tf.squeeze(score, axis=2) + return score, processed_query + + def apply_score_masking(self, score, mask): + """ ignore sequence paddings """ + padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2) + # Bias so padding positions do not contribute to attention distribution. + score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32) + return score + + def call(self, query, states): + """ + shapes: + query: B x D + """ + if self.use_loc_attn: + score, processed_query = self.get_loc_attn(query, states) + else: + score, processed_query = self.get_attn(query) + + # TODO: masking + # if mask is not None: + # self.apply_score_masking(score, mask) + # attn_weights shape == (batch_size, max_length, 1) + + attn_weights = self.norm_func(score) + + # update attention states + if self.use_loc_attn: + states = (states[0] + attn_weights, attn_weights) + else: + states = () + + # context_vector shape after sum == (batch_size, hidden_size) + context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False) + context_vector = tf.squeeze(context_vector, axis=1) + return context_vector, attn_weights, states + + +# def _location_sensitive_score(processed_query, keys, processed_loc, attention_v, attention_b): +# dtype = processed_query.dtype +# num_units = keys.shape[-1].value or array_ops.shape(keys)[-1] +# return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query + processed_loc + attention_b), [2]) + + +# class LocationSensitiveAttention(BahdanauAttention): +# def __init__(self, +# units, +# memory=None, +# memory_sequence_length=None, +# normalize=False, +# probability_fn="softmax", +# kernel_initializer="glorot_uniform", +# dtype=None, +# name="LocationSensitiveAttention", +# location_attention_filters=32, +# location_attention_kernel_size=31): + +# super(LocationSensitiveAttention, +# self).__init__(units=units, +# memory=memory, +# memory_sequence_length=memory_sequence_length, +# normalize=normalize, +# probability_fn='softmax', ## parent module default +# kernel_initializer=kernel_initializer, +# dtype=dtype, +# name=name) +# if probability_fn == 'sigmoid': +# self.probability_fn = lambda score, _: self._sigmoid_normalization(score) +# self.location_conv = keras.layers.Conv1D(filters=location_attention_filters, kernel_size=location_attention_kernel_size, padding='same', use_bias=False) +# self.location_dense = keras.layers.Dense(units, use_bias=False) +# # self.v = keras.layers.Dense(1, use_bias=True) + +# def _location_sensitive_score(self, processed_query, keys, processed_loc): +# processed_query = tf.expand_dims(processed_query, 1) +# return tf.reduce_sum(self.attention_v * tf.tanh(keys + processed_query + processed_loc), [2]) + +# def _location_sensitive(self, alignment_cum, alignment_old): +# alignment_cat = tf.stack([alignment_cum, alignment_old], axis=2) +# return self.location_dense(self.location_conv(alignment_cat)) + +# def _sigmoid_normalization(self, score): +# return tf.nn.sigmoid(score) / tf.reduce_sum(tf.nn.sigmoid(score), axis=-1, keepdims=True) + +# # def _apply_masking(self, score, mask): +# # padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2) +# # # Bias so padding positions do not contribute to attention distribution. +# # score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32) +# # return score + +# def _calculate_attention(self, query, state): +# alignment_cum, alignment_old = state[:2] +# processed_query = self.query_layer( +# query) if self.query_layer else query +# processed_loc = self._location_sensitive(alignment_cum, alignment_old) +# score = self._location_sensitive_score( +# processed_query, +# self.keys, +# processed_loc) +# alignment = self.probability_fn(score, state) +# alignment_cum = alignment_cum + alignment +# state[0] = alignment_cum +# state[1] = alignment +# return alignment, state + +# def compute_context(self, alignments): +# expanded_alignments = tf.expand_dims(alignments, 1) +# context = tf.matmul(expanded_alignments, self.values) +# context = tf.squeeze(context, [1]) +# return context + +# # def call(self, query, state): +# # alignment, next_state = self._calculate_attention(query, state) +# # return alignment, next_state diff --git a/tf/layers/tacotron2.py b/tf/layers/tacotron2.py new file mode 100644 index 00000000..4d787e83 --- /dev/null +++ b/tf/layers/tacotron2.py @@ -0,0 +1,231 @@ + +import tensorflow as tf +from tensorflow import keras +from TTS.tf.utils.tf_utils import shape_list +from TTS.tf.layers.common_layers import Prenet, Attention +# from tensorflow_addons.seq2seq import AttentionWrapper + + +class ConvBNBlock(keras.layers.Layer): + def __init__(self, filters, kernel_size, activation, **kwargs): + super(ConvBNBlock, self).__init__(**kwargs) + self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding='same', name='convolution1d') + self.batch_normalization = keras.layers.BatchNormalization(axis=2, momentum=0.90, epsilon=1e-5, name='batch_normalization') + self.dropout = keras.layers.Dropout(rate=0.5, name='dropout') + self.activation = keras.layers.Activation(activation, name='activation') + + def call(self, x, training=None): + o = self.convolution1d(x) + o = self.batch_normalization(o, training=training) + o = self.activation(o) + o = self.dropout(o, training=training) + return o + + +class Postnet(keras.layers.Layer): + def __init__(self, output_filters, num_convs, **kwargs): + super(Postnet, self).__init__(**kwargs) + self.convolutions = [] + self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name='convolutions_0')) + for idx in range(1, num_convs - 1): + self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name=f'convolutions_{idx}')) + self.convolutions.append(ConvBNBlock(output_filters, 5, 'linear', name=f'convolutions_{idx+1}')) + + def call(self, x, training=None): + o = x + for layer in self.convolutions: + o = layer(o, training=training) + return o + + +class Encoder(keras.layers.Layer): + def __init__(self, output_input_dim, **kwargs): + super(Encoder, self).__init__(**kwargs) + self.convolutions = [] + for idx in range(3): + self.convolutions.append(ConvBNBlock(output_input_dim, 5, 'relu', name=f'convolutions_{idx}')) + self.lstm = keras.layers.Bidirectional(keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name='lstm') + + def call(self, x, training=None): + o = x + for layer in self.convolutions: + o = layer(o, training=training) + o = self.lstm(o) + return o + + +class Decoder(keras.layers.Layer): + def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm, prenet_type, + prenet_dropout, use_forward_attn, use_trans_agent, use_forward_attn_mask, + use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, **kwargs): + super(Decoder, self).__init__(**kwargs) + self.frame_dim = frame_dim + self.r_init = tf.constant(r, dtype=tf.int32) + self.r = tf.constant(r, dtype=tf.int32) + self.separate_stopnet = separate_stopnet + self.max_decoder_steps = tf.constant(1000, dtype=tf.int32) + self.stop_thresh = tf.constant(0.5, dtype=tf.float32) + + # model dimensions + self.query_dim = 1024 + self.decoder_rnn_dim = 1024 + self.prenet_dim = 256 + self.attn_dim = 128 + self.p_attention_dropout = 0.1 + self.p_decoder_dropout = 0.1 + + self.prenet = Prenet(prenet_type, + prenet_dropout, + [self.prenet_dim, self.prenet_dim], + bias=False, + name='prenet') + self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name=f'{self.name}/attention_rnn', ) + self.attention_rnn_dropout = keras.layers.Dropout(0.5) + + # TODO: implement other attn options + self.attention = Attention(attn_dim=self.attn_dim, + use_loc_attn=True, + loc_attn_n_filters=32, + loc_attn_kernel_size=31, + use_windowing=False, + norm=attn_norm, + use_forward_attn=use_forward_attn, + use_trans_agent=use_trans_agent, + use_forward_attn_mask=use_forward_attn_mask, + name='attention') + self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name=f'{self.name}/decoder_rnn') + self.decoder_rnn_dropout = keras.layers.Dropout(0.5) + self.linear_projection = keras.layers.Dense(self.frame_dim * r, name=f'{self.name}/linear_projection/linear_layer') + self.stopnet = keras.layers.Dense(1, name=f'{self.name}/stopnet/linear_layer') + + + def set_max_decoder_steps(self, new_max_steps): + self.max_decoder_steps = tf.constant(new_max_steps, dtype=tf.int32) + + def set_r(self, new_r): + self.r = tf.constant(new_r, dtype=tf.int32) + + def build_decoder_initial_states(self, batch_size, memory_dim, memory_length): + zero_frame = tf.zeros([batch_size, self.frame_dim]) + zero_context = tf.zeros([batch_size, memory_dim]) + attention_rnn_state = self.attention_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32) + decoder_rnn_state = self.decoder_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32) + attention_states = self.attention.init_states(batch_size, memory_length) + return zero_frame, zero_context, attention_rnn_state, decoder_rnn_state, attention_states + + def step(self, prenet_next, states, + memory_seq_length=None, training=None): + _, context_next, attention_rnn_state, decoder_rnn_state, attention_states = states + attention_rnn_input = tf.concat([prenet_next, context_next], -1) + attention_rnn_output, attention_rnn_state = \ + self.attention_rnn(attention_rnn_input, + attention_rnn_state, training=training) + attention_rnn_output = self.attention_rnn_dropout(attention_rnn_output, training=training) + context, attention, attention_states = self.attention(attention_rnn_output, attention_states, training=training) + decoder_rnn_input = tf.concat([attention_rnn_output, context], -1) + decoder_rnn_output, decoder_rnn_state = \ + self.decoder_rnn(decoder_rnn_input, decoder_rnn_state, training=training) + decoder_rnn_output = self.decoder_rnn_dropout(decoder_rnn_output, training=training) + linear_projection_input = tf.concat([decoder_rnn_output, context], -1) + output_frame = self.linear_projection(linear_projection_input, training=training) + stopnet_input = tf.concat([decoder_rnn_output, output_frame], -1) + stopnet_output = self.stopnet(stopnet_input, training=training) + output_frame = output_frame[:, :self.r * self.frame_dim] + states = (output_frame[:, self.frame_dim * (self.r - 1):], context, attention_rnn_state, decoder_rnn_state, attention_states) + return output_frame, stopnet_output, states, attention + + def decode(self, memory, states, frames, memory_seq_length=None): + B, T, D = shape_list(memory) + num_iter = shape_list(frames)[1] // self.r + # init states + frame_zero = tf.expand_dims(states[0], 1) + frames = tf.concat([frame_zero, frames], axis=1) + outputs = tf.TensorArray(dtype=tf.float32, size=num_iter) + attentions = tf.TensorArray(dtype=tf.float32, size=num_iter) + stop_tokens = tf.TensorArray(dtype=tf.float32, size=num_iter) + # pre-computes + self.attention.process_values(memory) + prenet_output = self.prenet(frames, training=True) + step_count = tf.constant(0, dtype=tf.int32) + + def _body(step, memory, prenet_output, states, outputs, stop_tokens, attentions): + prenet_next = prenet_output[:, step] + output, stop_token, states, attention = self.step(prenet_next, + states, + memory_seq_length) + outputs = outputs.write(step, output) + attentions = attentions.write(step, attention) + stop_tokens = stop_tokens.write(step, stop_token) + return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions + _, memory, _, states, outputs, stop_tokens, attentions = \ + tf.while_loop(lambda *arg: True, + _body, + loop_vars=(step_count, memory, prenet_output, states, outputs, + stop_tokens, attentions), + parallel_iterations=32, + swap_memory=True, + maximum_iterations=num_iter) + + outputs = outputs.stack() + attentions = attentions.stack() + stop_tokens = stop_tokens.stack() + outputs = tf.transpose(outputs, [1, 0, 2]) + attentions = tf.transpose(attentions, [1, 0 ,2]) + stop_tokens = tf.transpose(stop_tokens, [1, 0, 2]) + stop_tokens = tf.squeeze(stop_tokens, axis=2) + outputs = tf.reshape(outputs, [B, -1, self.frame_dim]) + return outputs, stop_tokens, attentions + + def decode_inference(self, memory, states): + B, T, D = shape_list(memory) + # init states + outputs = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True) + attentions = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True) + stop_tokens = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True) + # pre-computes + self.attention.process_values(memory) + + # iter vars + stop_flag = tf.constant(False, dtype=tf.bool) + step_count = tf.constant(0, dtype=tf.int32) + + def _body(step, memory, states, outputs, stop_tokens, attentions, stop_flag): + frame_next = states[0] + prenet_next = self.prenet(frame_next, training=False) + output, stop_token, states, attention = self.step(prenet_next, + states, + None, + training=False) + stop_token = tf.math.sigmoid(stop_token) + outputs = outputs.write(step, output) + attentions = attentions.write(step, attention) + stop_tokens = stop_tokens.write(step, stop_token) + stop_flag = tf.greater(stop_token, self.stop_thresh) + stop_flag = tf.reduce_all(stop_flag) + return step + 1, memory, states, outputs, stop_tokens, attentions, stop_flag + + cond = lambda step, m, s, o, st, a, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool)) + _, memory, states, outputs, stop_tokens, attentions, stop_flag = \ + tf.while_loop(cond, + _body, + loop_vars=(step_count, memory, states, outputs, + stop_tokens, attentions, stop_flag), + parallel_iterations=32, + swap_memory=True, + maximum_iterations=self.max_decoder_steps) + + outputs = outputs.stack() + attentions = attentions.stack() + stop_tokens = stop_tokens.stack() + + outputs = tf.transpose(outputs, [1, 0, 2]) + attentions = tf.transpose(attentions, [1, 0, 2]) + stop_tokens = tf.transpose(stop_tokens, [1, 0, 2]) + stop_tokens = tf.squeeze(stop_tokens, axis=2) + outputs = tf.reshape(outputs, [B, -1, self.frame_dim]) + return outputs, stop_tokens, attentions + + def call(self, memory, states, frames=None, memory_seq_length=None, training=False): + if training: + return self.decode(memory, states, frames, memory_seq_length) + return self.decode_inference(memory, states) \ No newline at end of file diff --git a/tf/models/tacotron2.py b/tf/models/tacotron2.py new file mode 100644 index 00000000..8ddee666 --- /dev/null +++ b/tf/models/tacotron2.py @@ -0,0 +1,72 @@ +import tensorflow as tf +from tensorflow import keras + +from TTS.tf.layers.tacotron2 import Encoder, Decoder, Postnet +from TTS.tf.utils.tf_utils import shape_list + + +class Tacotron2(keras.models.Model): + def __init__(self, + num_chars, + num_speakers, + r, + postnet_output_dim=80, + decoder_output_dim=80, + attn_type='original', + attn_win=False, + attn_norm="softmax", + attn_K=4, + prenet_type="original", + prenet_dropout=True, + forward_attn=False, + trans_agent=False, + forward_attn_mask=False, + location_attn=True, + separate_stopnet=True, + bidirectional_decoder=False): + super(Tacotron2, self).__init__() + self.r = r + self.decoder_output_dim = decoder_output_dim + self.postnet_output_dim = postnet_output_dim + self.bidirectional_decoder = bidirectional_decoder + self.num_speakers = num_speakers + self.speaker_embed_dim = 256 + + self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding') + self.encoder = Encoder(512, name='encoder') + # TODO: most of the decoder args have no use at the momment + self.decoder = Decoder(decoder_output_dim, r, attn_type=attn_type, use_attn_win=attn_win, attn_norm=attn_norm, prenet_type=prenet_type, + prenet_dropout=prenet_dropout, use_forward_attn=forward_attn, use_trans_agent=trans_agent, use_forward_attn_mask=forward_attn_mask, + use_location_attn=location_attn, attn_K=attn_K, separate_stopnet=separate_stopnet, speaker_emb_dim=self.speaker_embed_dim) + self.postnet = Postnet(postnet_output_dim, 5, name='postnet') + + def call(self, characters, text_lengths=None, frames=None, training=None): + if training == True: + return self.training(characters, text_lengths, frames) + else: + return self.inference(characters) + + def training(self, characters, text_lengths, frames): + B, T = shape_list(characters) + embedding_vectors = self.embedding(characters, training=True) + encoder_output = self.encoder(embedding_vectors, training=True) + decoder_states = self.decoder.build_decoder_initial_states(B, 512, T) + decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, frames, text_lengths, training=True) + postnet_frames = self.postnet(decoder_frames, training=True) + output_frames = decoder_frames + postnet_frames + return decoder_frames, output_frames, attentions, stop_tokens + + def inference(self, characters): + B, T = shape_list(characters) + embedding_vectors = self.embedding(characters, training=False) + encoder_output = self.encoder(embedding_vectors, training=False) + decoder_states = self.decoder.build_decoder_initial_states(B, 512, T) + decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False) + postnet_frames = self.postnet(decoder_frames, training=False) + output_frames = decoder_frames + postnet_frames + print(output_frames.shape) + return decoder_frames, output_frames, attentions, stop_tokens + + + + diff --git a/tf/notebooks/Benchmark-TTS_tf.ipynb b/tf/notebooks/Benchmark-TTS_tf.ipynb new file mode 100644 index 00000000..5531460e --- /dev/null +++ b/tf/notebooks/Benchmark-TTS_tf.ipynb @@ -0,0 +1,708 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "Collapsed": "false" + }, + "source": [ + "This is to test TTS models with benchmark sentences for speech synthesis.\n", + "\n", + "Before running this script please DON'T FORGET: \n", + "- to set file paths.\n", + "- to download related model files from TTS and PWGAN.\n", + "- download or clone related repos, linked below.\n", + "- setup the repositories. ```python setup.py install```\n", + "- to checkout right commit versions (given next to the model) of TTS and PWGAN.\n", + "- to set the right paths in the cell below.\n", + "\n", + "Repositories:\n", + "- TTS: https://github.com/mozilla/TTS\n", + "- PWGAN: https://github.com/erogol/ParallelWaveGAN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false", + "scrolled": true + }, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "import os\n", + "\n", + "# you may need to change this depending on your system\n", + "os.environ['CUDA_VISIBLE_DEVICES']='1'\n", + "\n", + "import sys\n", + "import io\n", + "import torch \n", + "import tensorflow as tf\n", + "print(tf.config.list_physical_devices('GPU'))\n", + "\n", + "import time\n", + "import json\n", + "import yaml\n", + "import numpy as np\n", + "from collections import OrderedDict\n", + "import matplotlib.pyplot as plt\n", + "plt.rcParams[\"figure.figsize\"] = (16,5)\n", + "\n", + "import librosa\n", + "import librosa.display\n", + "\n", + "from TTS.tf.models.tacotron2 import Tacotron2\n", + "from TTS.tf.utils.generic_utils import setup_model, load_checkpoint\n", + "from TTS.utils.audio import AudioProcessor\n", + "from TTS.utils.io import load_config\n", + "from TTS.utils.synthesis import synthesis\n", + "from TTS.utils.visual import visualize\n", + "\n", + "import IPython\n", + "from IPython.display import Audio\n", + "\n", + "%matplotlib agg" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "def tts(model, text, CONFIG, use_cuda, ap, use_gl, figures=True):\n", + " t_1 = time.time()\n", + " waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens, inputs = synthesis(model, text, CONFIG, use_cuda, ap, None, None, False, CONFIG.enable_eos_bos_chars, use_gl, backend=BACKEND)\n", + " if CONFIG.model == \"Tacotron\" and not use_gl:\n", + " # coorect the normalization differences b/w TTS and the Vocoder.\n", + " mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n", + " print(mel_postnet_spec.shape)\n", + " print(\"max- \", mel_postnet_spec.max(), \" -- min- \", mel_postnet_spec.min())\n", + " if not use_gl:\n", + " waveform = vocoder_model.inference(torch.FloatTensor(mel_postnet_spec.T).unsqueeze(0))\n", + " mel_postnet_spec = ap._denormalize(mel_postnet_spec.T).T\n", + " if use_cuda and not use_gl:\n", + " waveform = waveform.cpu()\n", + " waveform = waveform.numpy()\n", + " waveform = waveform.squeeze()\n", + " rtf = (time.time() - t_1) / (len(waveform) / ap.sample_rate)\n", + " print(waveform.shape)\n", + " print(\" > Run-time: {}\".format(time.time() - t_1))\n", + " print(\" > Real-time factor: {}\".format(rtf))\n", + " if figures: \n", + " visualize(alignment, mel_postnet_spec, stop_tokens, text, ap.hop_length, CONFIG, ap._denormalize(mel_spec.T).T) \n", + " IPython.display.display(Audio(waveform, rate=CONFIG.audio['sample_rate'], normalize=True)) \n", + " os.makedirs(OUT_FOLDER, exist_ok=True)\n", + " file_name = text.replace(\" \", \"_\").replace(\".\",\"\") + \".wav\"\n", + " out_path = os.path.join(OUT_FOLDER, file_name)\n", + " ap.save_wav(waveform, out_path)\n", + " return alignment, mel_postnet_spec, stop_tokens, waveform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "# Set constants\n", + "ROOT_PATH = '../tf_model/'\n", + "MODEL_PATH = ROOT_PATH + '/tts_tf_checkpoint_360000.pkl'\n", + "CONFIG_PATH = ROOT_PATH + '/config.json'\n", + "OUT_FOLDER = '/home/erogol/Dropbox/AudioSamples/benchmark_samples/'\n", + "CONFIG = load_config(CONFIG_PATH)\n", + "# Run FLAGs\n", + "use_cuda = True\n", + "# Set the vocoder\n", + "use_gl = True # use GL if True\n", + "BACKEND = 'tf'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false", + "scrolled": true + }, + "outputs": [], + "source": [ + "from TTS.utils.text.symbols import symbols, phonemes, make_symbols\n", + "from TTS.tf.utils.convert_torch_to_tf_utils import tf_create_dummy_inputs\n", + "c = CONFIG\n", + "num_speakers = 0\n", + "r = 1\n", + "num_chars = len(phonemes) if c.use_phonemes else len(symbols)\n", + "model = setup_model(num_chars, num_speakers, c)\n", + "\n", + "# before loading weights you need to run the model once to generate the variables\n", + "input_ids, input_lengths, mel_outputs, mel_lengths = tf_create_dummy_inputs()\n", + "mel_pred = model(input_ids, training=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "model = load_checkpoint(model, MODEL_PATH)\n", + "# model = tf.function(model, experimental_relax_shapes=True)\n", + "ap = AudioProcessor(**CONFIG.audio) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "# wrapper class to use tf.function\n", + "class ModelInference(tf.keras.Model):\n", + " def __init__(self, model):\n", + " super(ModelInference, self).__init__()\n", + " self.model = model\n", + " \n", + " @tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.int32)])\n", + " def call(self, characters):\n", + " return self.model(characters, training=False)\n", + " \n", + "model = ModelInference(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "# LOAD WAVERNN\n", + "if use_gl == False:\n", + " from parallel_wavegan.models import ParallelWaveGANGenerator, MelGANGenerator\n", + " \n", + " vocoder_model = MelGANGenerator(**VOCODER_CONFIG[\"generator_params\"])\n", + " vocoder_model.load_state_dict(torch.load(VOCODER_MODEL_PATH, map_location=\"cpu\")[\"model\"][\"generator\"])\n", + " vocoder_model.remove_weight_norm()\n", + " ap_vocoder = AudioProcessor(**VOCODER_CONFIG['audio']) \n", + " if use_cuda:\n", + " vocoder_model.cuda()\n", + " vocoder_model.eval();\n", + " print(count_parameters(vocoder_model))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "Collapsed": "false" + }, + "source": [ + "### Comparision with https://mycroft.ai/blog/available-voices/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Bill got in the habit of asking himself “Is that thought true?” and if he wasn’t absolutely certain it was, he just let it go.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "Collapsed": "false" + }, + "source": [ + "### https://espnet.github.io/icassp2020-tts/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"The Commission also recommends\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"As a result of these studies, the planning document submitted by the Secretary of the Treasury to the Bureau of the Budget on August thirty-one.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"The FBI now transmits information on all defectors, a category which would, of course, have included Oswald.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"they seem unduly restrictive in continuing to require some manifestation of animus against a Government official.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"and each agency given clear understanding of the assistance which the Secret Service expects.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "Collapsed": "false" + }, + "source": [ + "### Other examples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Be a voice, not an echo.\" # 'echo' is not in training set. \n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"The human voice is the most perfect instrument of all.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"I'm sorry Dave. I'm afraid I can't do that.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"This cake is great. It's so delicious and moist.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "Collapsed": "false" + }, + "source": [ + "### Comparison with https://keithito.github.io/audio-samples/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Generative adversarial network or variational auto-encoder.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Scientists at the CERN laboratory say they have discovered a new particle.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Here’s a way to measure the acute emotional intelligence that has never gone out of style.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"President Trump met with other leaders at the Group of 20 conference.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"The buses aren't the problem, they actually provide a solution.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "Collapsed": "false" + }, + "source": [ + "### Comparison with https://google.github.io/tacotron/publications/tacotron/index.html" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Generative adversarial network or variational auto-encoder.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Basilar membrane and otolaryngology are not auto-correlations.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \" He has read the whole thing.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"He reads books.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Thisss isrealy awhsome.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"This is your internet browser, Firefox.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"This is your internet browser Firefox.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"The quick brown fox jumps over the lazy dog.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Does the quick brown fox jump over the lazy dog?\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Eren, how are you?\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "Collapsed": "false" + }, + "source": [ + "### Hard Sentences" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Encouraged, he started with a minute a day.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"His meditation consisted of “body scanning” which involved focusing his mind and energy on each section of the body from head to toe .\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase the grey matter in the parts of the brain responsible for emotional regulation and learning . \"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"If he decided to watch TV he really watched it.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Often we try to bring about change through sheer effort and we put all of our energy into a new initiative .\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "# for twb dataset\n", + "sentence = \"In our preparation for Easter, God in his providence offers us each year the season of Lent as a sacramental sign of our conversion.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "wavs = []\n", + "model.eval()\n", + "model.decoder.prenet.eval()\n", + "model.decoder.max_decoder_steps = 2000\n", + "# model.decoder.prenet.train()\n", + "speaker_id = None\n", + "sentence = '''This is App Store Optimization report.\n", + "The first tab on the report is App Details. App details report is updated weekly and Datetime column shows the latest report update date. The widget displays the app icon, respective app version, visual assets on the store, app description, latest app update date on the Appstore/Google PlayStore and what’s new section.\n", + "In App Details tab, you can see not only your app but all Delivery Hero apps since we think it can be inspiring to see the other apps, their description and screenshots. \n", + "Product name is the actual app name on the AppStore or Google Play Store.\n", + "Screenshot URLs column display the actual screenshots on the store for the current version. No resizing is done. If you click on the screenshot, you can see it in full-size.\n", + "Current release date show the latest app update date when the query is run. Here we see that Appetito24 Android is updated to app version 4.6.3.2 on 28th of March.\n", + "If the description is too long, clarisights is not able to display the full description; however, if you select description and current_release_date cells to copy and paste it to a text editor, you'll see the full description.\n", + "If you scroll down in the widget, you can see the older app versions for the same apps. Or you can filter Datetime to see a specific timeframe and the apps’ Store presence back then.\n", + "You can also filter for a specific app using Product Name.\n", + "If the description is too long, clarisights is not able to display the full description; however, if you select description and current_release_date cells to copy and paste it to a text editor, you'll see the full description.\n", + "'''\n", + "\n", + "for s in sentence.split('\\n'):\n", + " print(s)\n", + " align, spec, stop_tokens, wav = tts(model, s, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)\n", + " wavs = np.concatenate([wavs, np.zeros(int(ap.sample_rate * 0.5)), wav])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tf/requirements b/tf/requirements new file mode 100644 index 00000000..75882a1d --- /dev/null +++ b/tf/requirements @@ -0,0 +1,2 @@ +fuzzywuzzy +tensorflow>=2.2.0 \ No newline at end of file diff --git a/tf/utils/convert_torch_to_tf_utils.py b/tf/utils/convert_torch_to_tf_utils.py new file mode 100644 index 00000000..732f2fb5 --- /dev/null +++ b/tf/utils/convert_torch_to_tf_utils.py @@ -0,0 +1,83 @@ +import numpy as np +import torch +import re +import tensorflow as tf +import tensorflow.keras.backend as K + + +def tf_create_dummy_inputs(): + """ Create dummy inputs for TF Tacotron2 model """ + batch_size = 4 + max_input_length = 32 + max_mel_length = 128 + pad = 1 + n_chars = 24 + input_ids = tf.random.uniform([batch_size, max_input_length + pad], maxval=n_chars, dtype=tf.int32) + input_lengths = np.random.randint(0, high=max_input_length+1 + pad, size=[batch_size]) + input_lengths[-1] = max_input_length + input_lengths = tf.convert_to_tensor(input_lengths, dtype=tf.int32) + mel_outputs = tf.random.uniform(shape=[batch_size, max_mel_length + pad, 80]) + mel_lengths = np.random.randint(0, high=max_mel_length+1 + pad, size=[batch_size]) + mel_lengths[-1] = max_mel_length + mel_lengths = tf.convert_to_tensor(mel_lengths, dtype=tf.int32) + return input_ids, input_lengths, mel_outputs, mel_lengths + + +def compare_torch_tf(torch_tensor, tf_tensor): + """ Compute the average absolute difference b/w torch and tf tensors """ + return abs(torch_tensor.detach().numpy() - tf_tensor.numpy()).mean() + + +def convert_tf_name(tf_name): + """ Convert certain patterns in TF layer names to Torch patterns """ + tf_name_tmp = tf_name + tf_name_tmp = tf_name_tmp.replace(':0', '') + tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_1/recurrent_kernel', '/weight_hh_l0') + tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_2/kernel', '/weight_ih_l1') + tf_name_tmp = tf_name_tmp.replace('/recurrent_kernel', '/weight_hh') + tf_name_tmp = tf_name_tmp.replace('/kernel', '/weight') + tf_name_tmp = tf_name_tmp.replace('/gamma', '/weight') + tf_name_tmp = tf_name_tmp.replace('/beta', '/bias') + tf_name_tmp = tf_name_tmp.replace('/', '.') + return tf_name_tmp + + +def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict): + """ Transfer weigths from torch state_dict to TF variables """ + print(" > Passing weights from Torch to TF ...") + for tf_var in tf_vars: + torch_var_name = var_map_dict[tf_var.name] + print(f' | > {tf_var.name} <-- {torch_var_name}') + # if tuple, it is a bias variable + if type(torch_var_name) is not tuple: + torch_layer_name = '.'.join(torch_var_name.split('.')[-2:]) + torch_weight = state_dict[torch_var_name] + if 'convolution1d/kernel' in tf_var.name or 'conv1d/kernel' in tf_var.name: + # out_dim, in_dim, filter -> filter, in_dim, out_dim + numpy_weight = torch_weight.permute([2, 1, 0]).detach().cpu().numpy() + elif 'lstm_cell' in tf_var.name and 'kernel' in tf_var.name: + numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() + # if variable is for bidirectional lstm and it is a bias vector there + # needs to be pre-defined two matching torch bias vectors + elif '_lstm/lstm_cell_' in tf_var.name and 'bias' in tf_var.name: + bias_vectors = [value for key, value in state_dict.items() if key in torch_var_name] + assert len(bias_vectors) == 2 + numpy_weight = bias_vectors[0] + bias_vectors[1] + elif 'rnn' in tf_var.name and 'kernel' in tf_var.name: + numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() + elif 'rnn' in tf_var.name and 'bias' in tf_var.name: + bias_vectors = [value for key, value in state_dict.items() if torch_var_name[:-2] in key] + assert len(bias_vectors) == 2 + numpy_weight = bias_vectors[0] + bias_vectors[1] + elif 'linear_layer' in torch_layer_name and 'weight' in torch_var_name: + numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() + else: + numpy_weight = torch_weight.detach().cpu().numpy() + assert np.all(tf_var.shape == numpy_weight.shape), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}" + tf.keras.backend.set_value(tf_var, numpy_weight) + + +def load_tf_vars(model_tf, tf_vars): + for tf_var in tf_vars: + model_tf.get_layer(tf_var.name).set_weights(tf_var) + return model_tf diff --git a/tf/utils/generic_utils.py b/tf/utils/generic_utils.py new file mode 100644 index 00000000..3ef10a62 --- /dev/null +++ b/tf/utils/generic_utils.py @@ -0,0 +1,105 @@ +import os +import re +import glob +import shutil +import datetime +import json +import subprocess +import importlib +import pickle +import numpy as np +from collections import OrderedDict, Counter +import tensorflow as tf + + +def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs): + checkpoint_path = 'tts_tf_checkpoint_{}.pkl'.format(current_step) + checkpoint_path = os.path.join(output_folder, checkpoint_path) + state = { + 'model': model.weights, + 'optimizer': optimizer, + 'step': current_step, + 'epoch': epoch, + 'date': datetime.date.today().strftime("%B %d, %Y"), + 'r': r + } + state.update(kwargs) + pickle.dump(state, open(checkpoint_path, 'wb')) + + +def load_checkpoint(model, checkpoint_path): + checkpoint = pickle.load(open(checkpoint_path, 'rb')) + chkp_var_dict = dict([(var.name, var.numpy()) for var in checkpoint['model']]) + tf_vars = model.weights + for tf_var in tf_vars: + layer_name = tf_var.name + chkp_var_value = chkp_var_dict[layer_name] + tf.keras.backend.set_value(tf_var, chkp_var_value) + if 'r' in checkpoint.keys(): + model.decoder.set_r(checkpoint['r']) + return model + + +def sequence_mask(sequence_length, max_len=None): + if max_len is None: + max_len = sequence_length.max() + batch_size = sequence_length.size(0) + seq_range = np.empty([0, max_len], dtype=np.int8) + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + if sequence_length.is_cuda: + seq_range_expand = seq_range_expand.cuda() + seq_length_expand = ( + sequence_length.unsqueeze(1).expand_as(seq_range_expand)) + # B x T_max + return seq_range_expand < seq_length_expand + + +# @tf.custom_gradient +def check_gradient(x, grad_clip): + x_normed = tf.clip_by_norm(x, grad_clip) + grad_norm = tf.norm(grad_clip) + return x_normed, grad_norm + + +def count_parameters(model, c): + try: + return model.count_params() + except: + input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype('int32')) + input_lengths = np.random.randint(100, 129, (8, )) + input_lengths[-1] = 128 + input_lengths = tf.convert_to_tensor(input_lengths.astype('int32')) + mel_spec = np.random.rand(8, 2 * c.r, + c.audio['num_mels']).astype('float32') + mel_spec = tf.convert_to_tensor(mel_spec) + speaker_ids = np.random.randint( + 0, 5, (8, )) if c.use_speaker_embedding else None + _ = model(input_dummy, input_lengths, mel_spec) + return model.count_params() + + +def setup_model(num_chars, num_speakers, c): + print(" > Using model: {}".format(c.model)) + MyModel = importlib.import_module('TTS.tf.models.' + c.model.lower()) + MyModel = getattr(MyModel, c.model) + if c.model.lower() in "tacotron": + raise NotImplemented(' [!] Tacotron model is not ready.') + elif c.model.lower() == "tacotron2": + model = MyModel(num_chars=num_chars, + num_speakers=num_speakers, + r=c.r, + postnet_output_dim=c.audio['num_mels'], + decoder_output_dim=c.audio['num_mels'], + attn_type=c.attention_type, + attn_win=c.windowing, + attn_norm=c.attention_norm, + prenet_type=c.prenet_type, + prenet_dropout=c.prenet_dropout, + forward_attn=c.use_forward_attn, + trans_agent=c.transition_agent, + forward_attn_mask=c.forward_attn_mask, + location_attn=c.location_attn, + attn_K=c.attention_heads, + separate_stopnet=c.separate_stopnet, + bidirectional_decoder=c.bidirectional_decoder) + return model diff --git a/tf/utils/tf_utils.py b/tf/utils/tf_utils.py new file mode 100644 index 00000000..558936d5 --- /dev/null +++ b/tf/utils/tf_utils.py @@ -0,0 +1,8 @@ +import tensorflow as tf + + +def shape_list(x): + """Deal with dynamic shape in tensorflow cleanly.""" + static = x.shape.as_list() + dynamic = tf.shape(x) + return [dynamic[i] if s is None else s for i, s in enumerate(static)] From 88053706450b3e964a972a209d3bc66270a9f7b6 Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 18 May 2020 11:34:13 +0200 Subject: [PATCH 13/16] add tf tacotron2 test and edit test utils imports after utils refactoring --- tests/test_demo_server.py | 3 +- tests/test_loader.py | 2 +- tests/test_tacotron2_model.py | 2 +- tests/test_tacotron2_tf_model.py | 59 ++++++++++++++++++++++++++++++++ tests/test_tacotron_model.py | 2 +- tests/test_text_processing.py | 4 +-- 6 files changed, 66 insertions(+), 6 deletions(-) create mode 100644 tests/test_tacotron2_tf_model.py diff --git a/tests/test_demo_server.py b/tests/test_demo_server.py index a0837686..11d16a45 100644 --- a/tests/test_demo_server.py +++ b/tests/test_demo_server.py @@ -6,7 +6,8 @@ import torch as T from TTS.server.synthesizer import Synthesizer from TTS.tests import get_tests_input_path, get_tests_output_path from TTS.utils.text.symbols import make_symbols, phonemes, symbols -from TTS.utils.generic_utils import load_config, save_checkpoint, setup_model +from TTS.utils.generic_utils import setup_model +from TTS.utils.io import load_config, save_checkpoint class DemoServerTest(unittest.TestCase): diff --git a/tests/test_loader.py b/tests/test_loader.py index 447c7b38..9edd233f 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -5,7 +5,7 @@ import torch import numpy as np from torch.utils.data import DataLoader -from TTS.utils.generic_utils import load_config +from TTS.utils.io import load_config from TTS.utils.audio import AudioProcessor from TTS.datasets import TTSDataset from TTS.datasets.preprocess import ljspeech diff --git a/tests/test_tacotron2_model.py b/tests/test_tacotron2_model.py index aa2869eb..eb91b3cc 100644 --- a/tests/test_tacotron2_model.py +++ b/tests/test_tacotron2_model.py @@ -6,7 +6,7 @@ import numpy as np from torch import optim from torch import nn -from TTS.utils.generic_utils import load_config +from TTS.utils.io import load_config from TTS.layers.losses import MSELossMasked from TTS.models.tacotron2 import Tacotron2 diff --git a/tests/test_tacotron2_tf_model.py b/tests/test_tacotron2_tf_model.py new file mode 100644 index 00000000..27398748 --- /dev/null +++ b/tests/test_tacotron2_tf_model.py @@ -0,0 +1,59 @@ +import os +import copy +import torch +import unittest +import numpy as np +import tensorflow as tf + +from torch import optim +from torch import nn +from TTS.utils.io import load_config +from TTS.layers.losses import MSELossMasked +from TTS.tf.models.tacotron2 import Tacotron2 + +#pylint: disable=unused-variable + +torch.manual_seed(1) +use_cuda = torch.cuda.is_available() +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +file_path = os.path.dirname(os.path.realpath(__file__)) +c = load_config(os.path.join(file_path, 'test_config.json')) + + +class TacotronTFTrainTest(unittest.TestCase): + def test_train_step(self): + ''' test forward pass ''' + input = torch.randint(0, 24, (8, 128)).long().to(device) + input_lengths = torch.randint(100, 128, (8, )).long().to(device) + input_lengths = torch.sort(input_lengths, descending=True)[0] + mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) + mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) + mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + stop_targets = torch.zeros(8, 30, 1).float().to(device) + speaker_ids = torch.randint(0, 5, (8, )).long().to(device) + + input = tf.convert_to_tensor(input.cpu().numpy()) + input_lengths = tf.convert_to_tensor(input_lengths.cpu().numpy()) + mel_spec = tf.convert_to_tensor(mel_spec.cpu().numpy()) + + for idx in mel_lengths: + stop_targets[:, int(idx.item()):, 0] = 1.0 + + stop_targets = stop_targets.view(input.shape[0], + stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze() + + model = Tacotron2(num_chars=24, r=c.r, num_speakers=5) + # training pass + output = model(input, input_lengths, mel_spec, training=True) + + # check model output shapes + assert np.all(output[0].shape == mel_spec.shape) + assert np.all(output[1].shape == mel_spec.shape) + assert output[2].shape[2] == input.shape[1] + assert output[2].shape[1] == (mel_spec.shape[1] // model.decoder.r) + assert output[3].shape[1] == (mel_spec.shape[1] // model.decoder.r) + + # inference pass + output = model(input, training=False) diff --git a/tests/test_tacotron_model.py b/tests/test_tacotron_model.py index ac6712b0..7053a580 100644 --- a/tests/test_tacotron_model.py +++ b/tests/test_tacotron_model.py @@ -5,7 +5,7 @@ import unittest from torch import optim from torch import nn -from TTS.utils.generic_utils import load_config +from TTS.utils.io import load_config from TTS.layers.losses import L1LossMasked from TTS.models.tacotron import Tacotron diff --git a/tests/test_text_processing.py b/tests/test_text_processing.py index 6c0c7058..93edabe7 100644 --- a/tests/test_text_processing.py +++ b/tests/test_text_processing.py @@ -5,7 +5,7 @@ import os import unittest from TTS.utils.text import * from TTS.tests import get_tests_path -from TTS.utils.generic_utils import load_config +from TTS.utils.io import load_config TESTS_PATH = get_tests_path() conf = load_config(os.path.join(TESTS_PATH, 'test_config.json')) @@ -92,4 +92,4 @@ def test_text2phone(): gt = "ɹ|iː|s|ə|n|t| |ɹ|ɪ|s|ɜː|tʃ| |æ|t| |h|ɑːɹ|v|ɚ|d| |h|ɐ|z| |ʃ|oʊ|n| |m|ɛ|d|ᵻ|t|eɪ|ɾ|ɪ|ŋ| |f|ɔː|ɹ| |æ|z| |l|ɪ|ɾ|əl| |æ|z| |eɪ|t| |w|iː|k|s| |k|æ|n| |æ|k|tʃ|uː|əl|i| |ɪ|n|k|ɹ|iː|s|,| |ð|ə| |ɡ|ɹ|eɪ| |m|æ|ɾ|ɚ|ɹ| |ɪ|n|ð|ə| |p|ɑːɹ|t|s| |ʌ|v|ð|ə| |b|ɹ|eɪ|n| |ɹ|ɪ|s|p|ɑː|n|s|ə|b|əl| |f|ɔː|ɹ| |ɪ|m|oʊ|ʃ|ə|n|əl| |ɹ|ɛ|ɡ|j|uː|l|eɪ|ʃ|ə|n| |æ|n|d| |l|ɜː|n|ɪ|ŋ|!" lang = "en-us" ph = text2phone(text, lang) - assert gt == ph, f"\n{phonemes} \n vs \n{gt}" \ No newline at end of file + assert gt == ph, f"\n{phonemes} \n vs \n{gt}" From 523fa5dfd2c183ad030bbcbf4ea37f0a3ae82653 Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 18 May 2020 11:35:19 +0200 Subject: [PATCH 14/16] pass sequence mask to the same device as the input --- utils/generic_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/generic_utils.py b/utils/generic_utils.py index 9685f463..c81fde49 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -99,7 +99,7 @@ def sequence_mask(sequence_length, max_len=None): seq_range = torch.arange(0, max_len).long() seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) if sequence_length.is_cuda: - seq_range_expand = seq_range_expand.cuda() + seq_range_expand = seq_range_expand.to(sequence_length.device) seq_length_expand = ( sequence_length.unsqueeze(1).expand_as(seq_range_expand)) # B x T_max From 8e6aedccee7df04b405a70a6f190232e816b6c4c Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 18 May 2020 12:00:10 +0200 Subject: [PATCH 15/16] update readme --- tf/README.md | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tf/README.md b/tf/README.md index 24e09a06..04b6936c 100644 --- a/tf/README.md +++ b/tf/README.md @@ -1,4 +1,12 @@ ## Utilities to Convert Models to Tensorflow2 -You can find some utilities to convert Torch models to Tensorflow with an experimental Tacotron2 implemenation in Tensorflow2 (>=2.2). However, our released Torch models may not work with this module due to additional changes layer naming convention. Therefore, you need to train new models to be converted to TF. +Here there are utilities to convert trained Torch models to Tensorflow (2.2>=). -This is an experimental release. If you encounter an error, please put an issue or in the best send a PR but you are mostly on your own. \ No newline at end of file +We currently support Tacotron2 with Location Sensitive Attention. + +Be aware that our old Torch models may not work with this module due to additional changes in layer naming convention. Therefore, you need to train new models or handle these changes. + +We do not plan to share training scripts for Tensorflow in near future. But any contribution in that direction would be more than welcome. + +To see how you can use TF model at inference, check the notebook. + +This is an experimental release. If you encounter an error, please put an issue or in the best send a PR but you are mostly on your own. From 342d6303d41a5a6752db8e91ebdc923ef3eebc95 Mon Sep 17 00:00:00 2001 From: erogol Date: Mon, 18 May 2020 12:20:51 +0200 Subject: [PATCH 16/16] update TF model notebook --- tf/layers/tacotron2.py | 2 +- tf/notebooks/Benchmark-TTS_tf.ipynb | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tf/layers/tacotron2.py b/tf/layers/tacotron2.py index 4d787e83..b8e18cb1 100644 --- a/tf/layers/tacotron2.py +++ b/tf/layers/tacotron2.py @@ -228,4 +228,4 @@ class Decoder(keras.layers.Layer): def call(self, memory, states, frames=None, memory_seq_length=None, training=False): if training: return self.decode(memory, states, frames, memory_seq_length) - return self.decode_inference(memory, states) \ No newline at end of file + return self.decode_inference(memory, states) diff --git a/tf/notebooks/Benchmark-TTS_tf.ipynb b/tf/notebooks/Benchmark-TTS_tf.ipynb index 5531460e..c2b634e6 100644 --- a/tf/notebooks/Benchmark-TTS_tf.ipynb +++ b/tf/notebooks/Benchmark-TTS_tf.ipynb @@ -10,15 +10,14 @@ "\n", "Before running this script please DON'T FORGET: \n", "- to set file paths.\n", - "- to download related model files from TTS and PWGAN.\n", + "- to download related model files.\n", "- download or clone related repos, linked below.\n", "- setup the repositories. ```python setup.py install```\n", - "- to checkout right commit versions (given next to the model) of TTS and PWGAN.\n", - "- to set the right paths in the cell below.\n", + "- to checkout right commit versions (given next to the model in the models page).\n", + "- to set the file paths below.\n", "\n", "Repositories:\n", - "- TTS: https://github.com/mozilla/TTS\n", - "- PWGAN: https://github.com/erogol/ParallelWaveGAN" + "- TTS: https://github.com/mozilla/TTS" ] }, { @@ -151,7 +150,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "Collapsed": "false" + "Collapsed": "false", + "scrolled": true }, "outputs": [], "source": [