diff --git a/tf/README.md b/tf/README.md new file mode 100644 index 00000000..24e09a06 --- /dev/null +++ b/tf/README.md @@ -0,0 +1,4 @@ +## Utilities to Convert Models to TensorFlow 2 +Here you can find utilities to convert Torch models to TensorFlow, together with an experimental Tacotron2 implementation in TensorFlow 2 (>=2.2). However, our released Torch models may not work with this module due to additional changes in the layer naming convention. Therefore, you need to train new models to be converted to TF. + +This is an experimental release. If you encounter an error, please open an issue or, even better, send a PR, but you are mostly on your own. \ No newline at end of file diff --git a/tf/convert_tacotron2_torch_to_tf.py b/tf/convert_tacotron2_torch_to_tf.py new file mode 100644 index 00000000..512b0a4d --- /dev/null +++ b/tf/convert_tacotron2_torch_to_tf.py @@ -0,0 +1,196 @@ +# %% +import sys +sys.path.append('/home/erogol/Projects') +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '' +# %% +import argparse +import numpy as np +import torch +import tensorflow as tf +from fuzzywuzzy import fuzz + +from TTS.utils.text.symbols import make_symbols, phonemes, symbols +from TTS.utils.generic_utils import setup_model, count_parameters +from TTS.utils.io import load_config +from TTS_tf.models.tacotron2 import Tacotron2 +from TTS_tf.utils.convert_torch_to_tf_utils import compare_torch_tf, tf_create_dummy_inputs, transfer_weights_torch_to_tf, convert_tf_name +from TTS_tf.utils.generic_utils import save_checkpoint + + +parser = argparse.ArgumentParser() +parser.add_argument( + '--torch_model_path', + type=str, + help='Path to target torch model to be converted to TF.') +parser.add_argument( + '--config_path', + type=str, + help='Path to config file of torch model.') +parser.add_argument( + '--output_path', + type=str, + help='Path to save TF model weights.') +args = parser.parse_args() + +# load model config +config_path = args.config_path +c = load_config(config_path) +num_speakers = 0 + +# init torch model +num_chars = len(phonemes) if c.use_phonemes else len(symbols) +model = setup_model(num_chars, num_speakers, c) +checkpoint = torch.load(args.torch_model_path, map_location=torch.device('cpu')) +state_dict = checkpoint['model'] +model.load_state_dict(state_dict) + +# init tf model +model_tf = Tacotron2(num_chars=num_chars, + num_speakers=num_speakers, + r=model.decoder.r, + postnet_output_dim=c.audio['num_mels'], + decoder_output_dim=c.audio['num_mels'], + attn_type=c.attention_type, + attn_win=c.windowing, + attn_norm=c.attention_norm, + prenet_type=c.prenet_type, + prenet_dropout=c.prenet_dropout, + forward_attn=c.use_forward_attn, + trans_agent=c.transition_agent, + forward_attn_mask=c.forward_attn_mask, + location_attn=c.location_attn, + attn_K=c.attention_heads, + separate_stopnet=c.separate_stopnet, + bidirectional_decoder=c.bidirectional_decoder) + +# set initial layer mapping - these are not captured by the below heuristic approach +# TODO: set layer names so that we can remove this manual matching +common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE' +var_map = [ + ('tacotron2/embedding/embeddings:0', 'embedding.weight'), + ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/kernel:0', 'encoder.lstm.weight_ih_l0'), + ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0', 'encoder.lstm.weight_hh_l0'), + ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/kernel:0', 'encoder.lstm.weight_ih_l0_reverse'), + ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0', 
'encoder.lstm.weight_hh_l0_reverse'), + ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/bias:0', ('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')), + ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/bias:0', ('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')), + ('attention/v/kernel:0', 'decoder.attention.v.linear_layer.weight'), + ('decoder/linear_projection/kernel:0', 'decoder.linear_projection.linear_layer.weight'), + ('decoder/stopnet/kernel:0', 'decoder.stopnet.1.linear_layer.weight') +] + + +# %% +# get tf_model graph +input_ids, input_lengths, mel_outputs, mel_lengths = tf_create_dummy_inputs() +mel_pred = model_tf(input_ids, training=False) + +# get tf variables +tf_vars = model_tf.weights + +# match variable names with fuzzy logic +torch_var_names = list(state_dict.keys()) +tf_var_names = [we.name for we in model_tf.weights] +for tf_name in tf_var_names: + # skip re-mapped layer names + if tf_name in [name[0] for name in var_map]: + continue + tf_name_edited = convert_tf_name(tf_name) + ratios = [fuzz.ratio(torch_name, tf_name_edited) for torch_name in torch_var_names] + max_idx = np.argmax(ratios) + matching_name = torch_var_names[max_idx] + del torch_var_names[max_idx] + var_map.append((tf_name, matching_name)) + + +# %% +# print variable match +from pprint import pprint +pprint(var_map) +pprint(torch_var_names) + +# pass weights +tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict) + +# Compare TF and TORCH models +# %% +# check embedding outputs +model.eval() +input_ids = torch.randint(0, 24, (1, 128)).long() + +o_t = model.embedding(input_ids) +o_tf = model_tf.embedding(input_ids.detach().numpy()) +assert abs(o_t.detach().numpy() - o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() - o_tf.numpy()).sum() + +# compare encoder outputs +oo_en = model.encoder.inference(o_t.transpose(1,2)) +ooo_en = model_tf.encoder(o_t.detach().numpy(), training=False) +assert compare_torch_tf(oo_en, ooo_en) < 1e-5 + +# compare decoder.attention_rnn +inp = torch.rand([1, 768]) +inp_tf = inp.numpy() +model.decoder._init_states(oo_en, mask=None) +output, cell_state = model.decoder.attention_rnn(inp) +states = model_tf.decoder.build_decoder_initial_states(1,512,128) +output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf, states[2], training=False) +assert compare_torch_tf(output, output_tf).mean() < 1e-5 + +# compare decoder.attention +query = output +inputs = torch.rand([1, 128, 512]) +query_tf = query.detach().numpy() +inputs_tf = inputs.numpy() + +model.decoder.attention.init_states(inputs) +processes_inputs = model.decoder.attention.preprocess_inputs(inputs) +loc_attn, proc_query = model.decoder.attention.get_location_attention(query, processes_inputs) +context = model.decoder.attention(query, inputs, processes_inputs, None) + +model_tf.decoder.attention.process_values(tf.convert_to_tensor(inputs_tf)) +loc_attn_tf, proc_query_tf = model_tf.decoder.attention.get_loc_attn(query_tf) +context_tf = model_tf.decoder.attention(query_tf, training=False) + +assert compare_torch_tf(loc_attn, loc_attn_tf).mean() < 1e-5 +assert compare_torch_tf(proc_query, proc_query_tf).mean() < 1e-5 +assert compare_torch_tf(context, context_tf) < 1e-5 + +# compare decoder.decoder_rnn +input = torch.rand([1, 1536]) +input_tf = input.numpy() +model.decoder._init_states(oo_en, mask=None) +output, cell_state = model.decoder.decoder_rnn(input, [model.decoder.decoder_hidden, model.decoder.decoder_cell]) +states = 
model_tf.decoder.build_decoder_initial_states(1,512,128) +output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf, states[3], training=False) +assert abs(input - input_tf).mean() < 1e-5 +assert compare_torch_tf(output, output_tf).mean() < 1e-5 + +# compare decoder.linear_projection +input = torch.rand([1, 1536]) +input_tf = input.numpy() +output = model.decoder.linear_projection(input) +output_tf = model_tf.decoder.linear_projection(input_tf, training=False) +assert compare_torch_tf(output, output_tf) < 1e-5 + +# compare decoder outputs +model.decoder.max_decoder_steps = 100 +model_tf.decoder.set_max_decoder_steps(100) +output, align, stop = model.decoder.inference(oo_en) +states = model_tf.decoder.build_decoder_initial_states(1,512,128) +output_tf, align_tf, stop_tf = model_tf.decoder(ooo_en, states, training=False) +assert compare_torch_tf(output.transpose(1,2), output_tf) < 1e-4 + +# compare the whole model output +outputs_torch = model.inference(input_ids) +outputs_tf = model_tf(tf.convert_to_tensor(input_ids.numpy())) +print(abs(outputs_torch[0].numpy()[:, 0] - outputs_tf[0].numpy()[:, 0]).mean() ) +assert compare_torch_tf(outputs_torch[2][:, 50, :], outputs_tf[2][:, 50, :]) < 1e-5 +assert compare_torch_tf(outputs_torch[0], outputs_tf[0]) < 1e-4 + +# %% +# save tf model +save_checkpoint(model_tf, None, checkpoint['step'], checkpoint['epoch'], + checkpoint['r'], args.output_path) +print(' > Model conversion is successfully completed :).') + diff --git a/tf/layers/common_layers.py b/tf/layers/common_layers.py new file mode 100644 index 00000000..fba06e0b --- /dev/null +++ b/tf/layers/common_layers.py @@ -0,0 +1,258 @@ +import tensorflow as tf +from tensorflow import keras +from tensorflow.python.ops import math_ops +# from tensorflow_addons.seq2seq import BahdanauAttention + +from TTS.tf.utils.tf_utils import shape_list + + +class Linear(keras.layers.Layer): + def __init__(self, units, use_bias, **kwargs): + super(Linear, self).__init__(**kwargs) + self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer') + self.activation = keras.layers.ReLU() + + def call(self, x, training=None): + """ + shapes: + x: B x T x C + """ + return self.activation(self.linear_layer(x)) + + +class LinearBN(keras.layers.Layer): + def __init__(self, units, use_bias, **kwargs): + super(LinearBN, self).__init__(**kwargs) + self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer') + self.batch_normalization = keras.layers.BatchNormalization(axis=-1, momentum=0.90, epsilon=1e-5, name='batch_normalization') + self.activation = keras.layers.ReLU() + + def call(self, x, training=None): + """ + shapes: + x: B x T x C + """ + out = self.linear_layer(x) + out = self.batch_normalization(out, training=training) + return self.activation(out) + + +class Prenet(keras.layers.Layer): + def __init__(self, + prenet_type, + prenet_dropout, + units, + bias, + **kwargs): + super(Prenet, self).__init__(**kwargs) + self.prenet_type = prenet_type + self.prenet_dropout = prenet_dropout + self.linear_layers = [] + if prenet_type == "bn": + self.linear_layers += [LinearBN(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)] + elif prenet_type == "original": + self.linear_layers += [Linear(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)] + else: + raise RuntimeError(' [!] 
Unknown prenet type.') + if prenet_dropout: + self.dropout = keras.layers.Dropout(rate=0.5) + + def call(self, x, training=None): + """ + shapes: + x: B x T x C + """ + for linear in self.linear_layers: + if self.prenet_dropout: + x = self.dropout(linear(x), training=training) + else: + x = linear(x) + return x + + +def _sigmoid_norm(score): + attn_weights = tf.nn.sigmoid(score) + attn_weights = attn_weights / tf.reduce_sum(attn_weights, axis=1, keepdims=True) + return attn_weights + + +class Attention(keras.layers.Layer): + """TODO: implement forward_attention""" + """TODO: location sensitive attention""" + """TODO: implement attention windowing """ + def __init__(self, attn_dim, use_loc_attn, loc_attn_n_filters, + loc_attn_kernel_size, use_windowing, norm, use_forward_attn, + use_trans_agent, use_forward_attn_mask, **kwargs): + super(Attention, self).__init__(**kwargs) + self.use_loc_attn = use_loc_attn + self.loc_attn_n_filters = loc_attn_n_filters + self.loc_attn_kernel_size = loc_attn_kernel_size + self.use_windowing = use_windowing + self.norm = norm + self.use_forward_attn = use_forward_attn + self.use_trans_agent = use_trans_agent + self.use_forward_attn_mask = use_forward_attn_mask + self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name='query_layer/linear_layer') + self.inputs_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name=f'{self.name}/inputs_layer/linear_layer') + self.v = tf.keras.layers.Dense(1, use_bias=True, name='v/linear_layer') + if use_loc_attn: + self.location_conv1d = keras.layers.Conv1D( + filters=loc_attn_n_filters, + kernel_size=loc_attn_kernel_size, + padding='same', + use_bias=False, + name='location_layer/location_conv1d') + self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name='location_layer/location_dense') + if norm == 'softmax': + self.norm_func = tf.nn.softmax + elif norm == 'sigmoid': + self.norm_func = _sigmoid_norm + else: + raise ValueError("Unknown value for attention norm type") + + def init_states(self, batch_size, value_length): + states = () + if self.use_loc_attn: + attention_cum = tf.zeros([batch_size, value_length]) + attention_old = tf.zeros([batch_size, value_length]) + states = (attention_cum, attention_old) + return states + + def process_values(self, values): + """ cache values for decoder iterations """ + self.processed_values = self.inputs_layer(values) + self.values = values + + def get_loc_attn(self, query, states): + """ compute location attention, query layer and + unnorm. attention weights""" + attention_cum, attention_old = states + attn_cat = tf.stack([attention_old, attention_cum], + axis=2) + + processed_query = self.query_layer(tf.expand_dims(query, 1)) + processed_attn = self.location_dense(self.location_conv1d(attn_cat)) + score = self.v( + tf.nn.tanh(self.processed_values + processed_query + + processed_attn)) + score = tf.squeeze(score, axis=2) + return score, processed_query + + def get_attn(self, query): + """ compute query layer and unnormalized attention weights """ + processed_query = self.query_layer(tf.expand_dims(query, 1)) + score = self.v(tf.nn.tanh(self.processed_values + processed_query)) + score = tf.squeeze(score, axis=2) + return score, processed_query + + def apply_score_masking(self, score, mask): + """ ignore sequence paddings """ + padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2) + # Bias so padding positions do not contribute to attention distribution. 
+ score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32) + return score + + def call(self, query, states): + """ + shapes: + query: B x D + """ + if self.use_loc_attn: + score, processed_query = self.get_loc_attn(query, states) + else: + score, processed_query = self.get_attn(query) + + # TODO: masking + # if mask is not None: + # self.apply_score_masking(score, mask) + # attn_weights shape == (batch_size, max_length, 1) + + attn_weights = self.norm_func(score) + + # update attention states + if self.use_loc_attn: + states = (states[0] + attn_weights, attn_weights) + else: + states = () + + # context_vector shape after sum == (batch_size, hidden_size) + context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False) + context_vector = tf.squeeze(context_vector, axis=1) + return context_vector, attn_weights, states + + +# def _location_sensitive_score(processed_query, keys, processed_loc, attention_v, attention_b): +# dtype = processed_query.dtype +# num_units = keys.shape[-1].value or array_ops.shape(keys)[-1] +# return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query + processed_loc + attention_b), [2]) + + +# class LocationSensitiveAttention(BahdanauAttention): +# def __init__(self, +# units, +# memory=None, +# memory_sequence_length=None, +# normalize=False, +# probability_fn="softmax", +# kernel_initializer="glorot_uniform", +# dtype=None, +# name="LocationSensitiveAttention", +# location_attention_filters=32, +# location_attention_kernel_size=31): + +# super(LocationSensitiveAttention, +# self).__init__(units=units, +# memory=memory, +# memory_sequence_length=memory_sequence_length, +# normalize=normalize, +# probability_fn='softmax', ## parent module default +# kernel_initializer=kernel_initializer, +# dtype=dtype, +# name=name) +# if probability_fn == 'sigmoid': +# self.probability_fn = lambda score, _: self._sigmoid_normalization(score) +# self.location_conv = keras.layers.Conv1D(filters=location_attention_filters, kernel_size=location_attention_kernel_size, padding='same', use_bias=False) +# self.location_dense = keras.layers.Dense(units, use_bias=False) +# # self.v = keras.layers.Dense(1, use_bias=True) + +# def _location_sensitive_score(self, processed_query, keys, processed_loc): +# processed_query = tf.expand_dims(processed_query, 1) +# return tf.reduce_sum(self.attention_v * tf.tanh(keys + processed_query + processed_loc), [2]) + +# def _location_sensitive(self, alignment_cum, alignment_old): +# alignment_cat = tf.stack([alignment_cum, alignment_old], axis=2) +# return self.location_dense(self.location_conv(alignment_cat)) + +# def _sigmoid_normalization(self, score): +# return tf.nn.sigmoid(score) / tf.reduce_sum(tf.nn.sigmoid(score), axis=-1, keepdims=True) + +# # def _apply_masking(self, score, mask): +# # padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2) +# # # Bias so padding positions do not contribute to attention distribution. 
+# # score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32) +# # return score + +# def _calculate_attention(self, query, state): +# alignment_cum, alignment_old = state[:2] +# processed_query = self.query_layer( +# query) if self.query_layer else query +# processed_loc = self._location_sensitive(alignment_cum, alignment_old) +# score = self._location_sensitive_score( +# processed_query, +# self.keys, +# processed_loc) +# alignment = self.probability_fn(score, state) +# alignment_cum = alignment_cum + alignment +# state[0] = alignment_cum +# state[1] = alignment +# return alignment, state + +# def compute_context(self, alignments): +# expanded_alignments = tf.expand_dims(alignments, 1) +# context = tf.matmul(expanded_alignments, self.values) +# context = tf.squeeze(context, [1]) +# return context + +# # def call(self, query, state): +# # alignment, next_state = self._calculate_attention(query, state) +# # return alignment, next_state diff --git a/tf/layers/tacotron2.py b/tf/layers/tacotron2.py new file mode 100644 index 00000000..4d787e83 --- /dev/null +++ b/tf/layers/tacotron2.py @@ -0,0 +1,231 @@ + +import tensorflow as tf +from tensorflow import keras +from TTS.tf.utils.tf_utils import shape_list +from TTS.tf.layers.common_layers import Prenet, Attention +# from tensorflow_addons.seq2seq import AttentionWrapper + + +class ConvBNBlock(keras.layers.Layer): + def __init__(self, filters, kernel_size, activation, **kwargs): + super(ConvBNBlock, self).__init__(**kwargs) + self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding='same', name='convolution1d') + self.batch_normalization = keras.layers.BatchNormalization(axis=2, momentum=0.90, epsilon=1e-5, name='batch_normalization') + self.dropout = keras.layers.Dropout(rate=0.5, name='dropout') + self.activation = keras.layers.Activation(activation, name='activation') + + def call(self, x, training=None): + o = self.convolution1d(x) + o = self.batch_normalization(o, training=training) + o = self.activation(o) + o = self.dropout(o, training=training) + return o + + +class Postnet(keras.layers.Layer): + def __init__(self, output_filters, num_convs, **kwargs): + super(Postnet, self).__init__(**kwargs) + self.convolutions = [] + self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name='convolutions_0')) + for idx in range(1, num_convs - 1): + self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name=f'convolutions_{idx}')) + self.convolutions.append(ConvBNBlock(output_filters, 5, 'linear', name=f'convolutions_{idx+1}')) + + def call(self, x, training=None): + o = x + for layer in self.convolutions: + o = layer(o, training=training) + return o + + +class Encoder(keras.layers.Layer): + def __init__(self, output_input_dim, **kwargs): + super(Encoder, self).__init__(**kwargs) + self.convolutions = [] + for idx in range(3): + self.convolutions.append(ConvBNBlock(output_input_dim, 5, 'relu', name=f'convolutions_{idx}')) + self.lstm = keras.layers.Bidirectional(keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name='lstm') + + def call(self, x, training=None): + o = x + for layer in self.convolutions: + o = layer(o, training=training) + o = self.lstm(o) + return o + + +class Decoder(keras.layers.Layer): + def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm, prenet_type, + prenet_dropout, use_forward_attn, use_trans_agent, use_forward_attn_mask, + use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, **kwargs): + super(Decoder, self).__init__(**kwargs) + self.frame_dim = frame_dim 
+ self.r_init = tf.constant(r, dtype=tf.int32) + self.r = tf.constant(r, dtype=tf.int32) + self.separate_stopnet = separate_stopnet + self.max_decoder_steps = tf.constant(1000, dtype=tf.int32) + self.stop_thresh = tf.constant(0.5, dtype=tf.float32) + + # model dimensions + self.query_dim = 1024 + self.decoder_rnn_dim = 1024 + self.prenet_dim = 256 + self.attn_dim = 128 + self.p_attention_dropout = 0.1 + self.p_decoder_dropout = 0.1 + + self.prenet = Prenet(prenet_type, + prenet_dropout, + [self.prenet_dim, self.prenet_dim], + bias=False, + name='prenet') + self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name=f'{self.name}/attention_rnn', ) + self.attention_rnn_dropout = keras.layers.Dropout(0.5) + + # TODO: implement other attn options + self.attention = Attention(attn_dim=self.attn_dim, + use_loc_attn=True, + loc_attn_n_filters=32, + loc_attn_kernel_size=31, + use_windowing=False, + norm=attn_norm, + use_forward_attn=use_forward_attn, + use_trans_agent=use_trans_agent, + use_forward_attn_mask=use_forward_attn_mask, + name='attention') + self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name=f'{self.name}/decoder_rnn') + self.decoder_rnn_dropout = keras.layers.Dropout(0.5) + self.linear_projection = keras.layers.Dense(self.frame_dim * r, name=f'{self.name}/linear_projection/linear_layer') + self.stopnet = keras.layers.Dense(1, name=f'{self.name}/stopnet/linear_layer') + + + def set_max_decoder_steps(self, new_max_steps): + self.max_decoder_steps = tf.constant(new_max_steps, dtype=tf.int32) + + def set_r(self, new_r): + self.r = tf.constant(new_r, dtype=tf.int32) + + def build_decoder_initial_states(self, batch_size, memory_dim, memory_length): + zero_frame = tf.zeros([batch_size, self.frame_dim]) + zero_context = tf.zeros([batch_size, memory_dim]) + attention_rnn_state = self.attention_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32) + decoder_rnn_state = self.decoder_rnn.get_initial_state(batch_size=batch_size, dtype=tf.float32) + attention_states = self.attention.init_states(batch_size, memory_length) + return zero_frame, zero_context, attention_rnn_state, decoder_rnn_state, attention_states + + def step(self, prenet_next, states, + memory_seq_length=None, training=None): + _, context_next, attention_rnn_state, decoder_rnn_state, attention_states = states + attention_rnn_input = tf.concat([prenet_next, context_next], -1) + attention_rnn_output, attention_rnn_state = \ + self.attention_rnn(attention_rnn_input, + attention_rnn_state, training=training) + attention_rnn_output = self.attention_rnn_dropout(attention_rnn_output, training=training) + context, attention, attention_states = self.attention(attention_rnn_output, attention_states, training=training) + decoder_rnn_input = tf.concat([attention_rnn_output, context], -1) + decoder_rnn_output, decoder_rnn_state = \ + self.decoder_rnn(decoder_rnn_input, decoder_rnn_state, training=training) + decoder_rnn_output = self.decoder_rnn_dropout(decoder_rnn_output, training=training) + linear_projection_input = tf.concat([decoder_rnn_output, context], -1) + output_frame = self.linear_projection(linear_projection_input, training=training) + stopnet_input = tf.concat([decoder_rnn_output, output_frame], -1) + stopnet_output = self.stopnet(stopnet_input, training=training) + output_frame = output_frame[:, :self.r * self.frame_dim] + states = (output_frame[:, self.frame_dim * (self.r - 1):], context, attention_rnn_state, decoder_rnn_state, attention_states) + return output_frame, 
stopnet_output, states, attention + + def decode(self, memory, states, frames, memory_seq_length=None): + B, T, D = shape_list(memory) + num_iter = shape_list(frames)[1] // self.r + # init states + frame_zero = tf.expand_dims(states[0], 1) + frames = tf.concat([frame_zero, frames], axis=1) + outputs = tf.TensorArray(dtype=tf.float32, size=num_iter) + attentions = tf.TensorArray(dtype=tf.float32, size=num_iter) + stop_tokens = tf.TensorArray(dtype=tf.float32, size=num_iter) + # pre-computes + self.attention.process_values(memory) + prenet_output = self.prenet(frames, training=True) + step_count = tf.constant(0, dtype=tf.int32) + + def _body(step, memory, prenet_output, states, outputs, stop_tokens, attentions): + prenet_next = prenet_output[:, step] + output, stop_token, states, attention = self.step(prenet_next, + states, + memory_seq_length) + outputs = outputs.write(step, output) + attentions = attentions.write(step, attention) + stop_tokens = stop_tokens.write(step, stop_token) + return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions + _, memory, _, states, outputs, stop_tokens, attentions = \ + tf.while_loop(lambda *arg: True, + _body, + loop_vars=(step_count, memory, prenet_output, states, outputs, + stop_tokens, attentions), + parallel_iterations=32, + swap_memory=True, + maximum_iterations=num_iter) + + outputs = outputs.stack() + attentions = attentions.stack() + stop_tokens = stop_tokens.stack() + outputs = tf.transpose(outputs, [1, 0, 2]) + attentions = tf.transpose(attentions, [1, 0 ,2]) + stop_tokens = tf.transpose(stop_tokens, [1, 0, 2]) + stop_tokens = tf.squeeze(stop_tokens, axis=2) + outputs = tf.reshape(outputs, [B, -1, self.frame_dim]) + return outputs, stop_tokens, attentions + + def decode_inference(self, memory, states): + B, T, D = shape_list(memory) + # init states + outputs = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True) + attentions = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True) + stop_tokens = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True) + # pre-computes + self.attention.process_values(memory) + + # iter vars + stop_flag = tf.constant(False, dtype=tf.bool) + step_count = tf.constant(0, dtype=tf.int32) + + def _body(step, memory, states, outputs, stop_tokens, attentions, stop_flag): + frame_next = states[0] + prenet_next = self.prenet(frame_next, training=False) + output, stop_token, states, attention = self.step(prenet_next, + states, + None, + training=False) + stop_token = tf.math.sigmoid(stop_token) + outputs = outputs.write(step, output) + attentions = attentions.write(step, attention) + stop_tokens = stop_tokens.write(step, stop_token) + stop_flag = tf.greater(stop_token, self.stop_thresh) + stop_flag = tf.reduce_all(stop_flag) + return step + 1, memory, states, outputs, stop_tokens, attentions, stop_flag + + cond = lambda step, m, s, o, st, a, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool)) + _, memory, states, outputs, stop_tokens, attentions, stop_flag = \ + tf.while_loop(cond, + _body, + loop_vars=(step_count, memory, states, outputs, + stop_tokens, attentions, stop_flag), + parallel_iterations=32, + swap_memory=True, + maximum_iterations=self.max_decoder_steps) + + outputs = outputs.stack() + attentions = attentions.stack() + stop_tokens = stop_tokens.stack() + + outputs = tf.transpose(outputs, [1, 0, 2]) + attentions = tf.transpose(attentions, [1, 0, 2]) + stop_tokens = tf.transpose(stop_tokens, 
[1, 0, 2]) + stop_tokens = tf.squeeze(stop_tokens, axis=2) + outputs = tf.reshape(outputs, [B, -1, self.frame_dim]) + return outputs, stop_tokens, attentions + + def call(self, memory, states, frames=None, memory_seq_length=None, training=False): + if training: + return self.decode(memory, states, frames, memory_seq_length) + return self.decode_inference(memory, states) \ No newline at end of file diff --git a/tf/models/tacotron2.py b/tf/models/tacotron2.py new file mode 100644 index 00000000..8ddee666 --- /dev/null +++ b/tf/models/tacotron2.py @@ -0,0 +1,72 @@ +import tensorflow as tf +from tensorflow import keras + +from TTS.tf.layers.tacotron2 import Encoder, Decoder, Postnet +from TTS.tf.utils.tf_utils import shape_list + + +class Tacotron2(keras.models.Model): + def __init__(self, + num_chars, + num_speakers, + r, + postnet_output_dim=80, + decoder_output_dim=80, + attn_type='original', + attn_win=False, + attn_norm="softmax", + attn_K=4, + prenet_type="original", + prenet_dropout=True, + forward_attn=False, + trans_agent=False, + forward_attn_mask=False, + location_attn=True, + separate_stopnet=True, + bidirectional_decoder=False): + super(Tacotron2, self).__init__() + self.r = r + self.decoder_output_dim = decoder_output_dim + self.postnet_output_dim = postnet_output_dim + self.bidirectional_decoder = bidirectional_decoder + self.num_speakers = num_speakers + self.speaker_embed_dim = 256 + + self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding') + self.encoder = Encoder(512, name='encoder') + # TODO: most of the decoder args are not used at the moment + self.decoder = Decoder(decoder_output_dim, r, attn_type=attn_type, use_attn_win=attn_win, attn_norm=attn_norm, prenet_type=prenet_type, + prenet_dropout=prenet_dropout, use_forward_attn=forward_attn, use_trans_agent=trans_agent, use_forward_attn_mask=forward_attn_mask, + use_location_attn=location_attn, attn_K=attn_K, separate_stopnet=separate_stopnet, speaker_emb_dim=self.speaker_embed_dim) + self.postnet = Postnet(postnet_output_dim, 5, name='postnet') + + def call(self, characters, text_lengths=None, frames=None, training=None): + if training: + return self.training(characters, text_lengths, frames) + else: + return self.inference(characters) + + def training(self, characters, text_lengths, frames): + B, T = shape_list(characters) + embedding_vectors = self.embedding(characters, training=True) + encoder_output = self.encoder(embedding_vectors, training=True) + decoder_states = self.decoder.build_decoder_initial_states(B, 512, T) + decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, frames, text_lengths, training=True) + postnet_frames = self.postnet(decoder_frames, training=True) + output_frames = decoder_frames + postnet_frames + return decoder_frames, output_frames, attentions, stop_tokens + + def inference(self, characters): + B, T = shape_list(characters) + embedding_vectors = self.embedding(characters, training=False) + encoder_output = self.encoder(embedding_vectors, training=False) + decoder_states = self.decoder.build_decoder_initial_states(B, 512, T) + decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False) + postnet_frames = self.postnet(decoder_frames, training=False) + output_frames = decoder_frames + postnet_frames + print(output_frames.shape) + return decoder_frames, output_frames, attentions, stop_tokens + + + + diff --git a/tf/notebooks/Benchmark-TTS_tf.ipynb new file 
mode 100644 index 00000000..5531460e --- /dev/null +++ b/tf/notebooks/Benchmark-TTS_tf.ipynb @@ -0,0 +1,708 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "Collapsed": "false" + }, + "source": [ + "This is to test TTS models with benchmark sentences for speech synthesis.\n", + "\n", + "Before running this script please DON'T FORGET: \n", + "- to set file paths.\n", + "- to download related model files from TTS and PWGAN.\n", + "- download or clone related repos, linked below.\n", + "- setup the repositories. ```python setup.py install```\n", + "- to checkout right commit versions (given next to the model) of TTS and PWGAN.\n", + "- to set the right paths in the cell below.\n", + "\n", + "Repositories:\n", + "- TTS: https://github.com/mozilla/TTS\n", + "- PWGAN: https://github.com/erogol/ParallelWaveGAN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false", + "scrolled": true + }, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "import os\n", + "\n", + "# you may need to change this depending on your system\n", + "os.environ['CUDA_VISIBLE_DEVICES']='1'\n", + "\n", + "import sys\n", + "import io\n", + "import torch \n", + "import tensorflow as tf\n", + "print(tf.config.list_physical_devices('GPU'))\n", + "\n", + "import time\n", + "import json\n", + "import yaml\n", + "import numpy as np\n", + "from collections import OrderedDict\n", + "import matplotlib.pyplot as plt\n", + "plt.rcParams[\"figure.figsize\"] = (16,5)\n", + "\n", + "import librosa\n", + "import librosa.display\n", + "\n", + "from TTS.tf.models.tacotron2 import Tacotron2\n", + "from TTS.tf.utils.generic_utils import setup_model, load_checkpoint\n", + "from TTS.utils.audio import AudioProcessor\n", + "from TTS.utils.io import load_config\n", + "from TTS.utils.synthesis import synthesis\n", + "from TTS.utils.visual import visualize\n", + "\n", + "import IPython\n", + "from IPython.display import Audio\n", + "\n", + "%matplotlib agg" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "def tts(model, text, CONFIG, use_cuda, ap, use_gl, figures=True):\n", + " t_1 = time.time()\n", + " waveform, alignment, mel_spec, mel_postnet_spec, stop_tokens, inputs = synthesis(model, text, CONFIG, use_cuda, ap, None, None, False, CONFIG.enable_eos_bos_chars, use_gl, backend=BACKEND)\n", + " if CONFIG.model == \"Tacotron\" and not use_gl:\n", + " # coorect the normalization differences b/w TTS and the Vocoder.\n", + " mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T\n", + " print(mel_postnet_spec.shape)\n", + " print(\"max- \", mel_postnet_spec.max(), \" -- min- \", mel_postnet_spec.min())\n", + " if not use_gl:\n", + " waveform = vocoder_model.inference(torch.FloatTensor(mel_postnet_spec.T).unsqueeze(0))\n", + " mel_postnet_spec = ap._denormalize(mel_postnet_spec.T).T\n", + " if use_cuda and not use_gl:\n", + " waveform = waveform.cpu()\n", + " waveform = waveform.numpy()\n", + " waveform = waveform.squeeze()\n", + " rtf = (time.time() - t_1) / (len(waveform) / ap.sample_rate)\n", + " print(waveform.shape)\n", + " print(\" > Run-time: {}\".format(time.time() - t_1))\n", + " print(\" > Real-time factor: {}\".format(rtf))\n", + " if figures: \n", + " visualize(alignment, mel_postnet_spec, stop_tokens, text, ap.hop_length, CONFIG, ap._denormalize(mel_spec.T).T) \n", + " IPython.display.display(Audio(waveform, rate=CONFIG.audio['sample_rate'], 
normalize=True)) \n", + " os.makedirs(OUT_FOLDER, exist_ok=True)\n", + " file_name = text.replace(\" \", \"_\").replace(\".\",\"\") + \".wav\"\n", + " out_path = os.path.join(OUT_FOLDER, file_name)\n", + " ap.save_wav(waveform, out_path)\n", + " return alignment, mel_postnet_spec, stop_tokens, waveform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "# Set constants\n", + "ROOT_PATH = '../tf_model/'\n", + "MODEL_PATH = ROOT_PATH + '/tts_tf_checkpoint_360000.pkl'\n", + "CONFIG_PATH = ROOT_PATH + '/config.json'\n", + "OUT_FOLDER = '/home/erogol/Dropbox/AudioSamples/benchmark_samples/'\n", + "CONFIG = load_config(CONFIG_PATH)\n", + "# Run FLAGs\n", + "use_cuda = True\n", + "# Set the vocoder\n", + "use_gl = True # use GL if True\n", + "BACKEND = 'tf'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false", + "scrolled": true + }, + "outputs": [], + "source": [ + "from TTS.utils.text.symbols import symbols, phonemes, make_symbols\n", + "from TTS.tf.utils.convert_torch_to_tf_utils import tf_create_dummy_inputs\n", + "c = CONFIG\n", + "num_speakers = 0\n", + "r = 1\n", + "num_chars = len(phonemes) if c.use_phonemes else len(symbols)\n", + "model = setup_model(num_chars, num_speakers, c)\n", + "\n", + "# before loading weights you need to run the model once to generate the variables\n", + "input_ids, input_lengths, mel_outputs, mel_lengths = tf_create_dummy_inputs()\n", + "mel_pred = model(input_ids, training=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "model = load_checkpoint(model, MODEL_PATH)\n", + "# model = tf.function(model, experimental_relax_shapes=True)\n", + "ap = AudioProcessor(**CONFIG.audio) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "# wrapper class to use tf.function\n", + "class ModelInference(tf.keras.Model):\n", + " def __init__(self, model):\n", + " super(ModelInference, self).__init__()\n", + " self.model = model\n", + " \n", + " @tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.int32)])\n", + " def call(self, characters):\n", + " return self.model(characters, training=False)\n", + " \n", + "model = ModelInference(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "# LOAD WAVERNN\n", + "if use_gl == False:\n", + " from parallel_wavegan.models import ParallelWaveGANGenerator, MelGANGenerator\n", + " \n", + " vocoder_model = MelGANGenerator(**VOCODER_CONFIG[\"generator_params\"])\n", + " vocoder_model.load_state_dict(torch.load(VOCODER_MODEL_PATH, map_location=\"cpu\")[\"model\"][\"generator\"])\n", + " vocoder_model.remove_weight_norm()\n", + " ap_vocoder = AudioProcessor(**VOCODER_CONFIG['audio']) \n", + " if use_cuda:\n", + " vocoder_model.cuda()\n", + " vocoder_model.eval();\n", + " print(count_parameters(vocoder_model))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "Collapsed": "false" + }, + "source": [ + "### Comparision with https://mycroft.ai/blog/available-voices/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Bill got in the habit of asking himself “Is that thought true?” and if he wasn’t 
absolutely certain it was, he just let it go.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "Collapsed": "false" + }, + "source": [ + "### https://espnet.github.io/icassp2020-tts/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"The Commission also recommends\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"As a result of these studies, the planning document submitted by the Secretary of the Treasury to the Bureau of the Budget on August thirty-one.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"The FBI now transmits information on all defectors, a category which would, of course, have included Oswald.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"they seem unduly restrictive in continuing to require some manifestation of animus against a Government official.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"and each agency given clear understanding of the assistance which the Secret Service expects.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "Collapsed": "false" + }, + "source": [ + "### Other examples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Be a voice, not an echo.\" # 'echo' is not in training set. \n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"The human voice is the most perfect instrument of all.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"I'm sorry Dave. 
I'm afraid I can't do that.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"This cake is great. It's so delicious and moist.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "Collapsed": "false" + }, + "source": [ + "### Comparison with https://keithito.github.io/audio-samples/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Generative adversarial network or variational auto-encoder.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Scientists at the CERN laboratory say they have discovered a new particle.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Here’s a way to measure the acute emotional intelligence that has never gone out of style.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"President Trump met with other leaders at the Group of 20 conference.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"The buses aren't the problem, they actually provide a solution.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "Collapsed": "false" + }, + "source": [ + "### Comparison with https://google.github.io/tacotron/publications/tacotron/index.html" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Generative adversarial network or variational auto-encoder.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Basilar membrane and otolaryngology are not auto-correlations.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \" He has read the whole thing.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + 
"Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"He reads books.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Thisss isrealy awhsome.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"This is your internet browser, Firefox.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"This is your internet browser Firefox.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"The quick brown fox jumps over the lazy dog.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Does the quick brown fox jump over the lazy dog?\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Eren, how are you?\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "Collapsed": "false" + }, + "source": [ + "### Hard Sentences" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Encouraged, he started with a minute a day.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"His meditation consisted of “body scanning” which involved focusing his mind and energy on each section of the body from head to toe .\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase the grey matter in the parts of the brain responsible for emotional regulation and learning . 
\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"If he decided to watch TV he really watched it.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "sentence = \"Often we try to bring about change through sheer effort and we put all of our energy into a new initiative .\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "# for twb dataset\n", + "sentence = \"In our preparation for Easter, God in his providence offers us each year the season of Lent as a sacramental sign of our conversion.\"\n", + "align, spec, stop_tokens, wav = tts(model, sentence, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [ + "wavs = []\n", + "model.eval()\n", + "model.decoder.prenet.eval()\n", + "model.decoder.max_decoder_steps = 2000\n", + "# model.decoder.prenet.train()\n", + "speaker_id = None\n", + "sentence = '''This is App Store Optimization report.\n", + "The first tab on the report is App Details. App details report is updated weekly and Datetime column shows the latest report update date. The widget displays the app icon, respective app version, visual assets on the store, app description, latest app update date on the Appstore/Google PlayStore and what’s new section.\n", + "In App Details tab, you can see not only your app but all Delivery Hero apps since we think it can be inspiring to see the other apps, their description and screenshots. \n", + "Product name is the actual app name on the AppStore or Google Play Store.\n", + "Screenshot URLs column display the actual screenshots on the store for the current version. No resizing is done. If you click on the screenshot, you can see it in full-size.\n", + "Current release date show the latest app update date when the query is run. Here we see that Appetito24 Android is updated to app version 4.6.3.2 on 28th of March.\n", + "If the description is too long, clarisights is not able to display the full description; however, if you select description and current_release_date cells to copy and paste it to a text editor, you'll see the full description.\n", + "If you scroll down in the widget, you can see the older app versions for the same apps. 
Or you can filter Datetime to see a specific timeframe and the apps’ Store presence back then.\n", + "You can also filter for a specific app using Product Name.\n", + "If the description is too long, clarisights is not able to display the full description; however, if you select description and current_release_date cells to copy and paste it to a text editor, you'll see the full description.\n", + "'''\n", + "\n", + "for s in sentence.split('\\n'):\n", + " print(s)\n", + " align, spec, stop_tokens, wav = tts(model, s, CONFIG, use_cuda, ap, use_gl=use_gl, figures=True)\n", + " wavs = np.concatenate([wavs, np.zeros(int(ap.sample_rate * 0.5)), wav])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "Collapsed": "false" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tf/requirements b/tf/requirements new file mode 100644 index 00000000..75882a1d --- /dev/null +++ b/tf/requirements @@ -0,0 +1,2 @@ +fuzzywuzzy +tensorflow>=2.2.0 \ No newline at end of file diff --git a/tf/utils/convert_torch_to_tf_utils.py b/tf/utils/convert_torch_to_tf_utils.py new file mode 100644 index 00000000..732f2fb5 --- /dev/null +++ b/tf/utils/convert_torch_to_tf_utils.py @@ -0,0 +1,83 @@ +import numpy as np +import torch +import re +import tensorflow as tf +import tensorflow.keras.backend as K + + +def tf_create_dummy_inputs(): + """ Create dummy inputs for TF Tacotron2 model """ + batch_size = 4 + max_input_length = 32 + max_mel_length = 128 + pad = 1 + n_chars = 24 + input_ids = tf.random.uniform([batch_size, max_input_length + pad], maxval=n_chars, dtype=tf.int32) + input_lengths = np.random.randint(0, high=max_input_length+1 + pad, size=[batch_size]) + input_lengths[-1] = max_input_length + input_lengths = tf.convert_to_tensor(input_lengths, dtype=tf.int32) + mel_outputs = tf.random.uniform(shape=[batch_size, max_mel_length + pad, 80]) + mel_lengths = np.random.randint(0, high=max_mel_length+1 + pad, size=[batch_size]) + mel_lengths[-1] = max_mel_length + mel_lengths = tf.convert_to_tensor(mel_lengths, dtype=tf.int32) + return input_ids, input_lengths, mel_outputs, mel_lengths + + +def compare_torch_tf(torch_tensor, tf_tensor): + """ Compute the average absolute difference b/w torch and tf tensors """ + return abs(torch_tensor.detach().numpy() - tf_tensor.numpy()).mean() + + +def convert_tf_name(tf_name): + """ Convert certain patterns in TF layer names to Torch patterns """ + tf_name_tmp = tf_name + tf_name_tmp = tf_name_tmp.replace(':0', '') + tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_1/recurrent_kernel', '/weight_hh_l0') + tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_2/kernel', '/weight_ih_l1') + tf_name_tmp = tf_name_tmp.replace('/recurrent_kernel', '/weight_hh') + tf_name_tmp = tf_name_tmp.replace('/kernel', '/weight') + tf_name_tmp = tf_name_tmp.replace('/gamma', '/weight') + tf_name_tmp = tf_name_tmp.replace('/beta', '/bias') + tf_name_tmp = tf_name_tmp.replace('/', '.') + return tf_name_tmp + + +def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict): + """ Transfer weigths from torch state_dict to TF variables """ + 
print(" > Passing weights from Torch to TF ...") + for tf_var in tf_vars: + torch_var_name = var_map_dict[tf_var.name] + print(f' | > {tf_var.name} <-- {torch_var_name}') + # if tuple, it is a bias variable + if type(torch_var_name) is not tuple: + torch_layer_name = '.'.join(torch_var_name.split('.')[-2:]) + torch_weight = state_dict[torch_var_name] + if 'convolution1d/kernel' in tf_var.name or 'conv1d/kernel' in tf_var.name: + # out_dim, in_dim, filter -> filter, in_dim, out_dim + numpy_weight = torch_weight.permute([2, 1, 0]).detach().cpu().numpy() + elif 'lstm_cell' in tf_var.name and 'kernel' in tf_var.name: + numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() + # if variable is for bidirectional lstm and it is a bias vector there + # needs to be pre-defined two matching torch bias vectors + elif '_lstm/lstm_cell_' in tf_var.name and 'bias' in tf_var.name: + bias_vectors = [value for key, value in state_dict.items() if key in torch_var_name] + assert len(bias_vectors) == 2 + numpy_weight = bias_vectors[0] + bias_vectors[1] + elif 'rnn' in tf_var.name and 'kernel' in tf_var.name: + numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() + elif 'rnn' in tf_var.name and 'bias' in tf_var.name: + bias_vectors = [value for key, value in state_dict.items() if torch_var_name[:-2] in key] + assert len(bias_vectors) == 2 + numpy_weight = bias_vectors[0] + bias_vectors[1] + elif 'linear_layer' in torch_layer_name and 'weight' in torch_var_name: + numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() + else: + numpy_weight = torch_weight.detach().cpu().numpy() + assert np.all(tf_var.shape == numpy_weight.shape), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}" + tf.keras.backend.set_value(tf_var, numpy_weight) + + +def load_tf_vars(model_tf, tf_vars): + for tf_var in tf_vars: + model_tf.get_layer(tf_var.name).set_weights(tf_var) + return model_tf diff --git a/tf/utils/generic_utils.py b/tf/utils/generic_utils.py new file mode 100644 index 00000000..3ef10a62 --- /dev/null +++ b/tf/utils/generic_utils.py @@ -0,0 +1,105 @@ +import os +import re +import glob +import shutil +import datetime +import json +import subprocess +import importlib +import pickle +import numpy as np +from collections import OrderedDict, Counter +import tensorflow as tf + + +def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs): + checkpoint_path = 'tts_tf_checkpoint_{}.pkl'.format(current_step) + checkpoint_path = os.path.join(output_folder, checkpoint_path) + state = { + 'model': model.weights, + 'optimizer': optimizer, + 'step': current_step, + 'epoch': epoch, + 'date': datetime.date.today().strftime("%B %d, %Y"), + 'r': r + } + state.update(kwargs) + pickle.dump(state, open(checkpoint_path, 'wb')) + + +def load_checkpoint(model, checkpoint_path): + checkpoint = pickle.load(open(checkpoint_path, 'rb')) + chkp_var_dict = dict([(var.name, var.numpy()) for var in checkpoint['model']]) + tf_vars = model.weights + for tf_var in tf_vars: + layer_name = tf_var.name + chkp_var_value = chkp_var_dict[layer_name] + tf.keras.backend.set_value(tf_var, chkp_var_value) + if 'r' in checkpoint.keys(): + model.decoder.set_r(checkpoint['r']) + return model + + +def sequence_mask(sequence_length, max_len=None): + if max_len is None: + max_len = sequence_length.max() + batch_size = sequence_length.size(0) + seq_range = np.empty([0, max_len], dtype=np.int8) + seq_range_expand = 
seq_range.unsqueeze(0).expand(batch_size, max_len) + if sequence_length.is_cuda: + seq_range_expand = seq_range_expand.cuda() + seq_length_expand = ( + sequence_length.unsqueeze(1).expand_as(seq_range_expand)) + # B x T_max + return seq_range_expand < seq_length_expand + + +# @tf.custom_gradient +def check_gradient(x, grad_clip): + x_normed = tf.clip_by_norm(x, grad_clip) + grad_norm = tf.norm(grad_clip) + return x_normed, grad_norm + + +def count_parameters(model, c): + try: + return model.count_params() + except: + input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype('int32')) + input_lengths = np.random.randint(100, 129, (8, )) + input_lengths[-1] = 128 + input_lengths = tf.convert_to_tensor(input_lengths.astype('int32')) + mel_spec = np.random.rand(8, 2 * c.r, + c.audio['num_mels']).astype('float32') + mel_spec = tf.convert_to_tensor(mel_spec) + speaker_ids = np.random.randint( + 0, 5, (8, )) if c.use_speaker_embedding else None + _ = model(input_dummy, input_lengths, mel_spec) + return model.count_params() + + +def setup_model(num_chars, num_speakers, c): + print(" > Using model: {}".format(c.model)) + MyModel = importlib.import_module('TTS.tf.models.' + c.model.lower()) + MyModel = getattr(MyModel, c.model) + if c.model.lower() in "tacotron": + raise NotImplemented(' [!] Tacotron model is not ready.') + elif c.model.lower() == "tacotron2": + model = MyModel(num_chars=num_chars, + num_speakers=num_speakers, + r=c.r, + postnet_output_dim=c.audio['num_mels'], + decoder_output_dim=c.audio['num_mels'], + attn_type=c.attention_type, + attn_win=c.windowing, + attn_norm=c.attention_norm, + prenet_type=c.prenet_type, + prenet_dropout=c.prenet_dropout, + forward_attn=c.use_forward_attn, + trans_agent=c.transition_agent, + forward_attn_mask=c.forward_attn_mask, + location_attn=c.location_attn, + attn_K=c.attention_heads, + separate_stopnet=c.separate_stopnet, + bidirectional_decoder=c.bidirectional_decoder) + return model diff --git a/tf/utils/tf_utils.py b/tf/utils/tf_utils.py new file mode 100644 index 00000000..558936d5 --- /dev/null +++ b/tf/utils/tf_utils.py @@ -0,0 +1,8 @@ +import tensorflow as tf + + +def shape_list(x): + """Deal with dynamic shape in tensorflow cleanly.""" + static = x.shape.as_list() + dynamic = tf.shape(x) + return [dynamic[i] if s is None else s for i, s in enumerate(static)]
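For completeness, here is a minimal usage sketch of the utilities introduced above. It assumes the repository is installed so that the `TTS.tf` package is importable (the converter script itself imports it as `TTS_tf` in the author's local setup), and every path below is a placeholder. Step one runs the converter on a Torch checkpoint from the shell; step two rebuilds the TF model and loads the converted checkpoint, mirroring what the benchmark notebook does.

```python
# Step 1: convert a Torch Tacotron2 checkpoint (run from the shell, placeholder paths):
#   python tf/convert_tacotron2_torch_to_tf.py \
#       --torch_model_path /path/to/torch_checkpoint.pth.tar \
#       --config_path /path/to/config.json \
#       --output_path /path/to/tf_output/

# Step 2: rebuild the TF model and load the converted weights for inference.
import tensorflow as tf

from TTS.utils.io import load_config
from TTS.utils.text.symbols import phonemes, symbols
from TTS.tf.utils.generic_utils import setup_model, load_checkpoint

c = load_config('/path/to/config.json')  # placeholder path
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = setup_model(num_chars, 0, c)  # 0 speakers, as in the converter script

# run the model once so Keras creates its variables before the weights are set
dummy_ids = tf.random.uniform([1, 32], maxval=num_chars, dtype=tf.int32)
_ = model(dummy_ids, training=False)

# checkpoint name follows the 'tts_tf_checkpoint_<step>.pkl' pattern used by save_checkpoint
model = load_checkpoint(model, '/path/to/tf_output/tts_tf_checkpoint_360000.pkl')
```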