From 6b2ff0823924d5a5a197f4f7a07b252590826c29 Mon Sep 17 00:00:00 2001
From: erogol
Date: Fri, 19 Jun 2020 12:25:27 +0200
Subject: [PATCH] Issue #435 - Convert melgan vocoder models to TF2.0

---
 vocoder/tf/convert_melgan_torch_to_tf.py      | 115 ++++++++++++++++++
 vocoder/tf/layers/melgan.py                   |  52 ++++++++
 vocoder/tf/layers/pqmf.py                     |  66 ++++++++++
 vocoder/tf/models/melgan_generator.py         | 106 ++++++++++++++++
 .../tf/models/multiband_melgan_generator.py   |  46 +++++++
 vocoder/tf/utils/__init__.py                  |   0
 vocoder/tf/utils/convert_torch_to_tf_utils.py |  48 ++++++++
 vocoder/tf/utils/generic_utils.py             |  37 ++++++
 vocoder/tf/utils/io.py                        |  27 ++++
 9 files changed, 497 insertions(+)
 create mode 100644 vocoder/tf/convert_melgan_torch_to_tf.py
 create mode 100644 vocoder/tf/layers/melgan.py
 create mode 100644 vocoder/tf/layers/pqmf.py
 create mode 100644 vocoder/tf/models/melgan_generator.py
 create mode 100644 vocoder/tf/models/multiband_melgan_generator.py
 create mode 100644 vocoder/tf/utils/__init__.py
 create mode 100644 vocoder/tf/utils/convert_torch_to_tf_utils.py
 create mode 100644 vocoder/tf/utils/generic_utils.py
 create mode 100644 vocoder/tf/utils/io.py

diff --git a/vocoder/tf/convert_melgan_torch_to_tf.py b/vocoder/tf/convert_melgan_torch_to_tf.py
new file mode 100644
index 00000000..b767f268
--- /dev/null
+++ b/vocoder/tf/convert_melgan_torch_to_tf.py
@@ -0,0 +1,115 @@
+import argparse
+import os
+import sys
+from pprint import pprint
+
+import numpy as np
+import tensorflow as tf
+import torch
+from fuzzywuzzy import fuzz
+
+from TTS.utils.io import load_config
+from TTS.vocoder.tf.models.multiband_melgan_generator import \
+    MultibandMelganGenerator
+from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
+    compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf)
+from TTS.vocoder.tf.utils.generic_utils import \
+    setup_generator as setup_tf_generator
+from TTS.vocoder.tf.utils.io import save_checkpoint
+from TTS.vocoder.utils.generic_utils import setup_generator
+
+# prevent GPU use
+os.environ['CUDA_VISIBLE_DEVICES'] = ''
+
+# define args
+parser = argparse.ArgumentParser()
+parser.add_argument('--torch_model_path',
+                    type=str,
+                    help='Path to target torch model to be converted to TF.')
+parser.add_argument('--config_path',
+                    type=str,
+                    help='Path to config file of torch model.')
+parser.add_argument(
+    '--output_path',
+    type=str,
+    help='Path to output file including file name to save TF model.')
+args = parser.parse_args()
+
+# load model config
+config_path = args.config_path
+c = load_config(config_path)
+num_speakers = 0
+
+# init torch model
+model = setup_generator(c)
+checkpoint = torch.load(args.torch_model_path,
+                        map_location=torch.device('cpu'))
+state_dict = checkpoint['model']
+model.load_state_dict(state_dict)
+model.remove_weight_norm()
+state_dict = model.state_dict()
+
+# init tf model
+model_tf = setup_tf_generator(c)
+
+common_suffix = '/.ATTRIBUTES/VARIABLE_VALUE'
+# get tf_model graph by passing an input
+# B x D x T
+dummy_input = tf.random.uniform((7, 80, 64))
+mel_pred = model_tf(dummy_input, training=False)
+
+# get tf variables
+tf_vars = model_tf.weights
+
+# match variable names with fuzzy logic
+torch_var_names = list(state_dict.keys())
+tf_var_names = [we.name for we in model_tf.weights]
+# collect (tf_name, torch_name) pairs
+var_map = []
+for tf_name in tf_var_names:
+    # skip re-mapped layer names
+    if tf_name in [name[0] for name in var_map]:
+        continue
+    tf_name_edited = convert_tf_name(tf_name)
+    ratios = [
+        fuzz.ratio(torch_name, tf_name_edited)
+        for torch_name in torch_var_names
+    ]
+    max_idx = np.argmax(ratios)
+    matching_name = torch_var_names[max_idx]
+    del torch_var_names[max_idx]
+    var_map.append((tf_name, matching_name))
+
+# pass weights
+tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict)
+
+# compare TF and Torch models layer by layer
+# check the first layer output
+model.eval()
+dummy_input_torch = torch.ones((1, 80, 10))
+dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
+dummy_input_tf = tf.transpose(dummy_input_tf, perm=[0, 2, 1])
+dummy_input_tf = tf.expand_dims(dummy_input_tf, 2)
+
+out_torch = model.layers[0](dummy_input_torch)
+out_tf = model_tf.model_layers[0](dummy_input_tf)
+out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :]
+
+assert compare_torch_tf(out_torch, out_tf_) < 1e-5
+
+# check the remaining layers
+for i in range(1, len(model.layers)):
+    print(f"{i} -> {model.layers[i]} vs {model_tf.model_layers[i]}")
+    out_torch = model.layers[i](out_torch)
+    out_tf = model_tf.model_layers[i](out_tf)
+    out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :]
+    diff = compare_torch_tf(out_torch, out_tf_)
+    assert diff < 1e-5, diff
+
+# check the whole model end-to-end
+dummy_input_torch = torch.ones((1, 80, 10))
+dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
+output_torch = model.inference(dummy_input_torch)
+output_tf = model_tf(dummy_input_tf, training=False)
+assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(
+    output_torch, output_tf)
+
+# save tf model
+save_checkpoint(model_tf, checkpoint['step'], checkpoint['epoch'],
+                args.output_path)
+print(' > Model conversion completed successfully :).')
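A typical invocation of the converter above (run from a checkout where the TTS package is importable; the checkpoint, config and output paths are placeholders, not values from this patch):

    python vocoder/tf/convert_melgan_torch_to_tf.py \
        --torch_model_path /path/to/best_model.pth.tar \
        --config_path /path/to/config.json \
        --output_path /path/to/tf_model.pkl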
diff --git a/vocoder/tf/layers/melgan.py b/vocoder/tf/layers/melgan.py
new file mode 100644
index 00000000..8625bbab
--- /dev/null
+++ b/vocoder/tf/layers/melgan.py
@@ -0,0 +1,52 @@
+import tensorflow as tf
+
+
+class ReflectionPad1d(tf.keras.layers.Layer):
+    def __init__(self, padding):
+        super(ReflectionPad1d, self).__init__()
+        self.padding = padding
+
+    def call(self, x):
+        # x: B x T x 1 x C - reflection-pad the time axis only
+        return tf.pad(x, [[0, 0], [self.padding, self.padding], [0, 0], [0, 0]], "REFLECT")
+
+
+class ResidualStack(tf.keras.layers.Layer):
+    def __init__(self, channels, num_res_blocks, kernel_size, name):
+        super(ResidualStack, self).__init__(name=name)
+
+        assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd."
+        base_padding = (kernel_size - 1) // 2
+
+        self.blocks = []
+        num_layers = 2
+        for idx in range(num_res_blocks):
+            layer_kernel_size = kernel_size
+            layer_dilation = layer_kernel_size**idx
+            layer_padding = base_padding * layer_dilation
+            # each block: LeakyReLU -> reflection pad -> dilated conv -> LeakyReLU -> 1x1 conv
+            block = [
+                tf.keras.layers.LeakyReLU(0.2),
+                ReflectionPad1d(layer_padding),
+                tf.keras.layers.Conv2D(filters=channels,
+                                       kernel_size=(kernel_size, 1),
+                                       dilation_rate=(layer_dilation, 1),
+                                       use_bias=True,
+                                       padding='valid',
+                                       name=f'blocks.{idx}.{num_layers}'),
+                tf.keras.layers.LeakyReLU(0.2),
+                tf.keras.layers.Conv2D(filters=channels,
+                                       kernel_size=(1, 1),
+                                       use_bias=True,
+                                       name=f'blocks.{idx}.{num_layers + 2}')
+            ]
+            self.blocks.append(block)
+
+        self.shortcuts = [
+            tf.keras.layers.Conv2D(channels,
+                                   kernel_size=1,
+                                   use_bias=True,
+                                   name=f'shortcuts.{i}')
+            for i in range(num_res_blocks)
+        ]
+
+    def call(self, x):
+        for block, shortcut in zip(self.blocks, self.shortcuts):
+            res = shortcut(x)
+            for layer in block:
+                x = layer(x)
+            x += res
+        return x
\ No newline at end of file
diff --git a/vocoder/tf/layers/pqmf.py b/vocoder/tf/layers/pqmf.py
new file mode 100644
index 00000000..474b6e7f
--- /dev/null
+++ b/vocoder/tf/layers/pqmf.py
@@ -0,0 +1,66 @@
+import numpy as np
+import tensorflow as tf
+
+from scipy import signal as sig
+
+
+class PQMF(tf.keras.layers.Layer):
+    def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0):
+        super(PQMF, self).__init__()
+        # define filter coefficients
+        self.N = N
+        self.taps = taps
+        self.cutoff = cutoff
+        self.beta = beta
+
+        QMF = sig.firwin(taps + 1, cutoff, window=('kaiser', beta))
+        H = np.zeros((N, len(QMF)))
+        G = np.zeros((N, len(QMF)))
+        for k in range(N):
+            constant_factor = (2 * k + 1) * (np.pi /
+                                             (2 * N)) * (np.arange(taps + 1) -
+                                                         ((taps - 1) / 2))
+            phase = (-1)**k * np.pi / 4
+            H[k] = 2 * QMF * np.cos(constant_factor + phase)
+            G[k] = 2 * QMF * np.cos(constant_factor - phase)
+
+        # [taps + 1, 1, N] == [filter_width, in_channels, out_channels]
+        self.H = np.transpose(H[:, None, :], (2, 1, 0)).astype('float32')
+        self.G = np.transpose(G[None, :, :], (2, 1, 0)).astype('float32')
+
+        # filter for downsampling & upsampling
+        updown_filter = np.zeros((N, N, N), dtype=np.float32)
+        for k in range(N):
+            updown_filter[0, k, k] = 1.0
+        self.updown_filter = updown_filter.astype(np.float32)
+
+    def analysis(self, x):
+        """
+        x : B x 1 x T
+        """
+        x = tf.transpose(x, perm=[0, 2, 1])
+        x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]])
+        x = tf.nn.conv1d(x, self.H, stride=1, padding='VALID')
+        x = tf.nn.conv1d(x,
+                         self.updown_filter,
+                         stride=self.N,
+                         padding='VALID')
+        x = tf.transpose(x, perm=[0, 2, 1])
+        return x
+
+    def synthesis(self, x):
+        """
+        x : B x N x T
+        """
+        x = tf.transpose(x, perm=[0, 2, 1])
+        x = tf.nn.conv1d_transpose(
+            x,
+            self.updown_filter * self.N,
+            strides=self.N,
+            output_shape=(tf.shape(x)[0], tf.shape(x)[1] * self.N,
+                          self.N))
+        x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]])
+        x = tf.nn.conv1d(x, self.G, stride=1, padding="VALID")
+        x = tf.transpose(x, perm=[0, 2, 1])
+        return x
\ No newline at end of file
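A minimal round-trip check of the PQMF layer above (a sketch, not part of the patch; it assumes the default N=4 and an input length divisible by 4):

    import tensorflow as tf
    from TTS.vocoder.tf.layers.pqmf import PQMF

    pqmf = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0)
    audio = tf.random.uniform((1, 1, 4096))  # B x 1 x T
    subbands = pqmf.analysis(audio)          # B x 4 x T/4
    recon = pqmf.synthesis(subbands)         # B x 1 x T, approximately reconstructs the input
    print(subbands.shape, recon.shape)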
diff --git a/vocoder/tf/models/melgan_generator.py b/vocoder/tf/models/melgan_generator.py
new file mode 100644
index 00000000..db0a9675
--- /dev/null
+++ b/vocoder/tf/models/melgan_generator.py
@@ -0,0 +1,106 @@
+import logging
+import os
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # FATAL
+logging.getLogger('tensorflow').setLevel(logging.FATAL)
+
+import tensorflow as tf
+from TTS.vocoder.tf.layers.melgan import ResidualStack, ReflectionPad1d
+
+
+class MelganGenerator(tf.keras.models.Model):
+    """ MelGAN generator TF implementation, intended for inference only
+    (weight norm is already removed on the torch side). """
+    def __init__(self,
+                 in_channels=80,
+                 out_channels=1,
+                 proj_kernel=7,
+                 base_channels=512,
+                 upsample_factors=(8, 8, 2, 2),
+                 res_kernel=3,
+                 num_res_blocks=3):
+        super(MelganGenerator, self).__init__()
+
+        # assert model parameters
+        assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number."
+
+        # setup additional model parameters
+        base_padding = (proj_kernel - 1) // 2
+        act_slope = 0.2
+        self.inference_padding = 2
+
+        # initial layer
+        self.initial_layer = [
+            ReflectionPad1d(base_padding),
+            tf.keras.layers.Conv2D(filters=base_channels,
+                                   kernel_size=(proj_kernel, 1),
+                                   strides=1,
+                                   padding='valid',
+                                   use_bias=True,
+                                   name="1")
+        ]
+        num_layers = 3  # count number of layers for layer naming
+
+        # upsampling layers and residual stacks
+        self.upsample_layers = []
+        for idx, upsample_factor in enumerate(upsample_factors):
+            layer_out_channels = base_channels // (2**(idx + 1))
+            layer_filter_size = upsample_factor * 2
+            layer_stride = upsample_factor
+            layer_output_padding = upsample_factor % 2
+            self.upsample_layers += [
+                tf.keras.layers.LeakyReLU(act_slope),
+                tf.keras.layers.Conv2DTranspose(
+                    filters=layer_out_channels,
+                    kernel_size=(layer_filter_size, 1),
+                    strides=(layer_stride, 1),
+                    padding='same',
+                    # output_padding=layer_output_padding,
+                    use_bias=True,
+                    name=f'{num_layers}'),
+                ResidualStack(channels=layer_out_channels,
+                              num_res_blocks=num_res_blocks,
+                              kernel_size=res_kernel,
+                              name=f'layers.{num_layers + 1}')
+            ]
+            num_layers += num_res_blocks - 1
+
+        self.upsample_layers += [tf.keras.layers.LeakyReLU(act_slope)]
+
+        # final layer
+        self.final_layers = [
+            ReflectionPad1d(base_padding),
+            tf.keras.layers.Conv2D(filters=out_channels,
+                                   kernel_size=(proj_kernel, 1),
+                                   use_bias=True,
+                                   name=f'layers.{num_layers + 1}'),
+            tf.keras.layers.Activation("tanh")
+        ]
+
+        # keep a flat list of layers for sequential inference and layer-wise comparison
+        self.model_layers = self.initial_layer + self.upsample_layers + self.final_layers
+
+    def call(self, c, training=False):
+        """
+        c : B x C x T
+        """
+        if training:
+            raise NotImplementedError()
+        return self.inference(c)
+
+    def inference(self, c):
+        c = tf.transpose(c, perm=[0, 2, 1])
+        c = tf.expand_dims(c, 2)
+        c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
+        o = c
+        for layer in self.model_layers:
+            o = layer(o)
+        o = tf.transpose(o, perm=[0, 3, 2, 1])
+        return o[:, :, 0, :]
\ No newline at end of file
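A rough shape check for the generator above (a sketch, not part of the patch; it uses the constructor defaults, where the upsample factors 8*8*2*2 give 256 samples per input frame, plus a margin of 2*inference_padding frames):

    import tensorflow as tf
    from TTS.vocoder.tf.models.melgan_generator import MelganGenerator

    model = MelganGenerator()             # defaults: 80 mel channels -> 1 waveform channel
    mel = tf.random.uniform((1, 80, 50))  # B x C x T
    wav = model(mel, training=False)      # B x 1 x (T + 4) * 256 with the defaults
    print(wav.shape)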
diff --git a/vocoder/tf/models/multiband_melgan_generator.py b/vocoder/tf/models/multiband_melgan_generator.py
new file mode 100644
index 00000000..e8599760
--- /dev/null
+++ b/vocoder/tf/models/multiband_melgan_generator.py
@@ -0,0 +1,46 @@
+import tensorflow as tf
+
+from TTS.vocoder.tf.models.melgan_generator import MelganGenerator
+from TTS.vocoder.tf.layers.pqmf import PQMF
+
+
+class MultibandMelganGenerator(MelganGenerator):
+    def __init__(self,
+                 in_channels=80,
+                 out_channels=4,
+                 proj_kernel=7,
+                 base_channels=384,
+                 upsample_factors=(2, 8, 2, 2),
+                 res_kernel=3,
+                 num_res_blocks=3):
+        super(MultibandMelganGenerator,
+              self).__init__(in_channels=in_channels,
+                             out_channels=out_channels,
+                             proj_kernel=proj_kernel,
+                             base_channels=base_channels,
+                             upsample_factors=upsample_factors,
+                             res_kernel=res_kernel,
+                             num_res_blocks=num_res_blocks)
+        self.pqmf_layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0)
+
+    def pqmf_analysis(self, x):
+        return self.pqmf_layer.analysis(x)
+
+    def pqmf_synthesis(self, x):
+        return self.pqmf_layer.synthesis(x)
+
+    def inference(self, c):
+        c = tf.transpose(c, perm=[0, 2, 1])
+        c = tf.expand_dims(c, 2)
+        c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
+        o = c
+        for layer in self.model_layers:
+            o = layer(o)
+        o = tf.transpose(o, perm=[0, 3, 2, 1])
+        # recombine the subband outputs into a single full-band waveform
+        o = self.pqmf_layer.synthesis(o[:, :, 0, :])
+        return o
diff --git a/vocoder/tf/utils/__init__.py b/vocoder/tf/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/vocoder/tf/utils/convert_torch_to_tf_utils.py b/vocoder/tf/utils/convert_torch_to_tf_utils.py
new file mode 100644
index 00000000..799235e3
--- /dev/null
+++ b/vocoder/tf/utils/convert_torch_to_tf_utils.py
@@ -0,0 +1,48 @@
+import numpy as np
+import tensorflow as tf
+
+
+def compare_torch_tf(torch_tensor, tf_tensor):
+    """ Compute the average absolute difference b/w torch and tf tensors """
+    return abs(torch_tensor.detach().numpy() - tf_tensor.numpy()).mean()
+
+
+def convert_tf_name(tf_name):
+    """ Convert certain patterns in TF layer names to Torch patterns """
+    tf_name_tmp = tf_name
+    tf_name_tmp = tf_name_tmp.replace(':0', '')
+    tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_1/recurrent_kernel', '/weight_hh_l0')
+    tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_2/kernel', '/weight_ih_l1')
+    tf_name_tmp = tf_name_tmp.replace('/recurrent_kernel', '/weight_hh')
+    tf_name_tmp = tf_name_tmp.replace('/kernel', '/weight')
+    tf_name_tmp = tf_name_tmp.replace('/gamma', '/weight')
+    tf_name_tmp = tf_name_tmp.replace('/beta', '/bias')
+    tf_name_tmp = tf_name_tmp.replace('/', '.')
+    return tf_name_tmp
+
+
+def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict):
+    """ Transfer weights from a torch state_dict to TF variables """
+    print(" > Passing weights from Torch to TF ...")
+    for tf_var in tf_vars:
+        torch_var_name = var_map_dict[tf_var.name]
+        print(f' | > {tf_var.name} <-- {torch_var_name}')
+        if 'kernel' in tf_var.name:
+            # torch Conv1d kernels are [out, in, width]; the TF Conv2D layers
+            # with (k, 1) kernels expect [width, 1, in, out]
+            torch_weight = state_dict[torch_var_name]
+            numpy_weight = torch_weight.permute([2, 1, 0]).numpy()[:, None, :, :]
+        if 'bias' in tf_var.name:
+            torch_weight = state_dict[torch_var_name]
+            numpy_weight = torch_weight
+        assert np.all(tf_var.shape == numpy_weight.shape), f" [!] weight shapes do not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}"
+        tf.keras.backend.set_value(tf_var, numpy_weight)
+    return tf_vars
+
+
+def load_tf_vars(model_tf, tf_vars):
+    for tf_var in tf_vars:
+        model_tf.get_layer(tf_var.name).set_weights(tf_var)
+    return model_tf
diff --git a/vocoder/tf/utils/generic_utils.py b/vocoder/tf/utils/generic_utils.py
new file mode 100644
index 00000000..b17db596
--- /dev/null
+++ b/vocoder/tf/utils/generic_utils.py
@@ -0,0 +1,37 @@
+import re
+import importlib
+
+
+def to_camel(text):
+    text = text.capitalize()
+    return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
+
+
+def setup_generator(c):
+    print(" > Generator Model: {}".format(c.generator_model))
+    MyModel = importlib.import_module('TTS.vocoder.tf.models.' +
+                                      c.generator_model.lower())
+    MyModel = getattr(MyModel, to_camel(c.generator_model))
+    if c.generator_model == 'melgan_generator':
+        model = MyModel(
+            in_channels=c.audio['num_mels'],
+            out_channels=1,
+            proj_kernel=7,
+            base_channels=512,
+            upsample_factors=c.generator_model_params['upsample_factors'],
+            res_kernel=3,
+            num_res_blocks=c.generator_model_params['num_res_blocks'])
+    if c.generator_model == 'melgan_fb_generator':
+        raise NotImplementedError(
+            ' [!] melgan_fb_generator is not implemented for TF yet.')
+    if c.generator_model == 'multiband_melgan_generator':
+        model = MyModel(
+            in_channels=c.audio['num_mels'],
+            out_channels=4,
+            proj_kernel=7,
+            base_channels=384,
+            upsample_factors=c.generator_model_params['upsample_factors'],
+            res_kernel=3,
+            num_res_blocks=c.generator_model_params['num_res_blocks'])
+    return model
\ No newline at end of file
diff --git a/vocoder/tf/utils/io.py b/vocoder/tf/utils/io.py
new file mode 100644
index 00000000..d95d972c
--- /dev/null
+++ b/vocoder/tf/utils/io.py
@@ -0,0 +1,27 @@
+import datetime
+import pickle
+
+import tensorflow as tf
+
+
+def save_checkpoint(model, current_step, epoch, output_path, **kwargs):
+    """ Save TF Vocoder model """
+    state = {
+        'model': model.weights,
+        'step': current_step,
+        'epoch': epoch,
+        'date': datetime.date.today().strftime("%B %d, %Y"),
+    }
+    state.update(kwargs)
+    with open(output_path, 'wb') as f:
+        pickle.dump(state, f)
+
+
+def load_checkpoint(model, checkpoint_path):
+    """ Load TF Vocoder model """
+    with open(checkpoint_path, 'rb') as f:
+        checkpoint = pickle.load(f)
+    chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']}
+    tf_vars = model.weights
+    for tf_var in tf_vars:
+        layer_name = tf_var.name
+        chkp_var_value = chkp_var_dict[layer_name]
+        tf.keras.backend.set_value(tf_var, chkp_var_value)
+    return model
\ No newline at end of file
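For reference, a sketch of how the converted checkpoint could be used afterwards (not part of this patch; paths are placeholders and the snippet assumes the package layout used above):

    import tensorflow as tf
    from TTS.utils.io import load_config
    from TTS.vocoder.tf.utils.generic_utils import setup_generator
    from TTS.vocoder.tf.utils.io import load_checkpoint

    c = load_config('/path/to/config.json')  # config of the original torch model
    model = setup_generator(c)
    # run a dummy forward pass so the variables are built before loading
    model(tf.random.uniform((1, c.audio['num_mels'], 16)), training=False)
    model = load_checkpoint(model, '/path/to/tf_model.pkl')
    wav = model(tf.random.uniform((1, c.audio['num_mels'], 50)), training=False)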