mirror of https://github.com/coqui-ai/TTS.git
Issue #435 - Convert melgan vocoder models to TF2.0
This commit is contained in: parent 58784ad09c, commit 6b2ff08239

@@ -0,0 +1,115 @@ (new file: the Torch-to-TF conversion script; module paths in the later hunk headers are inferred from this commit's imports)
import argparse
import os
import sys
from pprint import pprint

import numpy as np
import tensorflow as tf
import torch
from fuzzywuzzy import fuzz

from TTS.utils.io import load_config
from TTS.vocoder.tf.models.multiband_melgan_generator import \
    MultibandMelganGenerator
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
    compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf)
from TTS.vocoder.tf.utils.generic_utils import \
    setup_generator as setup_tf_generator
from TTS.vocoder.tf.utils.io import save_checkpoint
from TTS.vocoder.utils.generic_utils import setup_generator

# prevent GPU use
os.environ['CUDA_VISIBLE_DEVICES'] = ''

# define args
parser = argparse.ArgumentParser()
parser.add_argument('--torch_model_path',
                    type=str,
                    help='Path to target torch model to be converted to TF.')
parser.add_argument('--config_path',
                    type=str,
                    help='Path to config file of torch model.')
parser.add_argument(
    '--output_path',
    type=str,
    help='Path to output file, including the file name, to save the TF model.')
args = parser.parse_args()

# load model config
config_path = args.config_path
c = load_config(config_path)
num_speakers = 0

# init torch model
model = setup_generator(c)
checkpoint = torch.load(args.torch_model_path,
                        map_location=torch.device('cpu'))
state_dict = checkpoint['model']
model.load_state_dict(state_dict)
model.remove_weight_norm()
state_dict = model.state_dict()

# init tf model
model_tf = setup_tf_generator(c)

common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
# get the tf_model graph by passing an input
# B x D x T
dummy_input = tf.random.uniform((7, 80, 64))
mel_pred = model_tf(dummy_input, training=False)

# get tf variables
tf_vars = model_tf.weights

# match variable names with fuzzy logic
torch_var_names = list(state_dict.keys())
tf_var_names = [we.name for we in model_tf.weights]
var_map = []
for tf_name in tf_var_names:
    # skip re-mapped layer names
    if tf_name in [name[0] for name in var_map]:
        continue
    tf_name_edited = convert_tf_name(tf_name)
    ratios = [
        fuzz.ratio(torch_name, tf_name_edited)
        for torch_name in torch_var_names
    ]
    max_idx = np.argmax(ratios)
    matching_name = torch_var_names[max_idx]
    del torch_var_names[max_idx]
    var_map.append((tf_name, matching_name))

# pass weights
tf_vars = transfer_weights_torch_to_tf(tf_vars, dict(var_map), state_dict)

# compare TF and Torch models layer by layer
model.eval()
dummy_input_torch = torch.ones((1, 80, 10))
dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
dummy_input_tf = tf.transpose(dummy_input_tf, perm=[0, 2, 1])
dummy_input_tf = tf.expand_dims(dummy_input_tf, 2)

out_torch = model.layers[0](dummy_input_torch)
out_tf = model_tf.model_layers[0](dummy_input_tf)
out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :]

assert compare_torch_tf(out_torch, out_tf_) < 1e-5

for i in range(1, len(model.layers)):
    print(f"{i} -> {model.layers[i]} vs {model_tf.model_layers[i]}")
    out_torch = model.layers[i](out_torch)
    out_tf = model_tf.model_layers[i](out_tf)
    out_tf_ = tf.transpose(out_tf, perm=[0, 3, 2, 1])[:, :, 0, :]
    diff = compare_torch_tf(out_torch, out_tf_)
    assert diff < 1e-5, diff

# compare the full inference outputs
dummy_input_torch = torch.ones((1, 80, 10))
dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
output_torch = model.inference(dummy_input_torch)
output_tf = model_tf(dummy_input_tf, training=False)
assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(
    output_torch, output_tf)

# save tf model
save_checkpoint(model_tf, checkpoint['step'], checkpoint['epoch'],
                args.output_path)
print(' > Model conversion is successfully completed :).')
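Once the script finishes, the pickled checkpoint can be reloaded with the load_checkpoint helper added in this commit. A minimal usage sketch, with hypothetical config.json and tf_model.pkl paths:

import tensorflow as tf

from TTS.utils.io import load_config
from TTS.vocoder.tf.utils.generic_utils import setup_generator as setup_tf_generator
from TTS.vocoder.tf.utils.io import load_checkpoint

c = load_config('config.json')                           # hypothetical path
vocoder = setup_tf_generator(c)
vocoder(tf.random.uniform((1, 80, 64)), training=False)  # build the variables
vocoder = load_checkpoint(vocoder, 'tf_model.pkl')       # hypothetical path
wav = vocoder(tf.random.uniform((1, 80, 64)), training=False)  # B x 1 x T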
@@ -0,0 +1,52 @@ TTS/vocoder/tf/layers/melgan.py
import tensorflow as tf


class ReflectionPad1d(tf.keras.layers.Layer):
    def __init__(self, padding):
        super(ReflectionPad1d, self).__init__()
        self.padding = padding

    def call(self, x):
        # x: B x T x 1 x C -- reflect-pad along the time axis
        return tf.pad(x, [[0, 0], [self.padding, self.padding], [0, 0], [0, 0]], "REFLECT")


class ResidualStack(tf.keras.layers.Layer):
    def __init__(self, channels, num_res_blocks, kernel_size, name):
        super(ResidualStack, self).__init__(name=name)

        assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd."
        base_padding = (kernel_size - 1) // 2

        self.blocks = []
        num_layers = 2
        for idx in range(num_res_blocks):
            layer_kernel_size = kernel_size
            layer_dilation = layer_kernel_size**idx
            layer_padding = base_padding * layer_dilation
            block = [
                tf.keras.layers.LeakyReLU(0.2),
                ReflectionPad1d(layer_padding),
                tf.keras.layers.Conv2D(filters=channels,
                                       kernel_size=(kernel_size, 1),
                                       dilation_rate=(layer_dilation, 1),
                                       use_bias=True,
                                       padding='valid',
                                       name=f'blocks.{idx}.{num_layers}'),
                tf.keras.layers.LeakyReLU(0.2),
                tf.keras.layers.Conv2D(filters=channels,
                                       kernel_size=(1, 1),
                                       use_bias=True,
                                       name=f'blocks.{idx}.{num_layers + 2}')
            ]
            self.blocks.append(block)
        self.shortcuts = [
            tf.keras.layers.Conv2D(channels,
                                   kernel_size=1,
                                   use_bias=True,
                                   name=f'shortcuts.{i}')
            for i in range(num_res_blocks)
        ]

    def call(self, x):
        for block, shortcut in zip(self.blocks, self.shortcuts):
            res = shortcut(x)
            for layer in block:
                x = layer(x)
            x += res
        return x
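A quick, hedged sanity check (not part of the module): with the class above in scope, ReflectionPad1d mirrors edge frames along the time axis of the B x T x 1 x C layout used by these models.

import tensorflow as tf

pad = ReflectionPad1d(2)
x = tf.reshape(tf.range(4, dtype=tf.float32), (1, 4, 1, 1))  # frames 0..3
y = pad(x)
print(y.shape)                       # (1, 8, 1, 1)
print(tf.reshape(y, (-1,)).numpy())  # [2. 1. 0. 1. 2. 3. 2. 1.]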
@@ -0,0 +1,66 @@ TTS/vocoder/tf/layers/pqmf.py
import numpy as np
import tensorflow as tf

from scipy import signal as sig


class PQMF(tf.keras.layers.Layer):
    def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0):
        super(PQMF, self).__init__()
        # define filter coefficients
        self.N = N
        self.taps = taps
        self.cutoff = cutoff
        self.beta = beta

        QMF = sig.firwin(taps + 1, cutoff, window=('kaiser', beta))
        H = np.zeros((N, len(QMF)))
        G = np.zeros((N, len(QMF)))
        for k in range(N):
            constant_factor = (2 * k + 1) * (np.pi /
                                             (2 * N)) * (np.arange(taps + 1) -
                                                         ((taps - 1) / 2))
            phase = (-1)**k * np.pi / 4
            H[k] = 2 * QMF * np.cos(constant_factor + phase)
            G[k] = 2 * QMF * np.cos(constant_factor - phase)

        # transpose to [taps + 1, in_channels, out_channels] as expected by tf.nn.conv1d
        self.H = np.transpose(H[:, None, :], (2, 1, 0)).astype('float32')
        self.G = np.transpose(G[None, :, :], (2, 1, 0)).astype('float32')

        # filter for downsampling & upsampling
        updown_filter = np.zeros((N, N, N), dtype=np.float32)
        for k in range(N):
            updown_filter[0, k, k] = 1.0
        self.updown_filter = updown_filter.astype(np.float32)

    def analysis(self, x):
        """
        x : B x 1 x T
        """
        x = tf.transpose(x, perm=[0, 2, 1])
        x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]])
        x = tf.nn.conv1d(x, self.H, stride=1, padding='VALID')
        x = tf.nn.conv1d(x,
                         self.updown_filter,
                         stride=self.N,
                         padding='VALID')
        x = tf.transpose(x, perm=[0, 2, 1])
        return x

    def synthesis(self, x):
        """
        x : B x N x T
        """
        x = tf.transpose(x, perm=[0, 2, 1])
        x = tf.nn.conv1d_transpose(
            x,
            self.updown_filter * self.N,
            strides=self.N,
            output_shape=(tf.shape(x)[0], tf.shape(x)[1] * self.N,
                          self.N))
        x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]])
        x = tf.nn.conv1d(x, self.G, stride=1, padding="VALID")
        x = tf.transpose(x, perm=[0, 2, 1])
        return x
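A hedged round-trip sketch with the class above: analysis splits a waveform into N subbands at 1/N of the sample rate, and synthesis recombines them with near-perfect reconstruction.

import tensorflow as tf

pqmf = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0)
wav = tf.random.uniform((1, 1, 4096))  # B x 1 x T
subbands = pqmf.analysis(wav)          # B x 4 x T/4
recon = pqmf.synthesis(subbands)       # B x 1 x T, close to the input
print(subbands.shape, recon.shape)     # (1, 4, 1024) (1, 1, 4096)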
@@ -0,0 +1,106 @@ TTS/vocoder/tf/models/melgan_generator.py
import logging
import os

# silence TF C++ and Python logging before importing tensorflow
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # FATAL
logging.getLogger('tensorflow').setLevel(logging.FATAL)

import tensorflow as tf

from TTS.vocoder.tf.layers.melgan import ResidualStack, ReflectionPad1d


class MelganGenerator(tf.keras.models.Model):
    """ MelGAN generator: a TF implementation dedicated to inference, with
    weight norm already removed """
    def __init__(self,
                 in_channels=80,
                 out_channels=1,
                 proj_kernel=7,
                 base_channels=512,
                 upsample_factors=(8, 8, 2, 2),
                 res_kernel=3,
                 num_res_blocks=3):
        super(MelganGenerator, self).__init__()

        # assert model parameters
        assert (proj_kernel -
                1) % 2 == 0, " [!] proj_kernel should be an odd number."

        # setup additional model parameters
        base_padding = (proj_kernel - 1) // 2
        act_slope = 0.2
        self.inference_padding = 2

        # initial layer
        self.initial_layer = [
            ReflectionPad1d(base_padding),
            tf.keras.layers.Conv2D(filters=base_channels,
                                   kernel_size=(proj_kernel, 1),
                                   strides=1,
                                   padding='valid',
                                   use_bias=True,
                                   name="1")
        ]
        num_layers = 3  # count number of layers for layer naming

        # upsampling layers and residual stacks
        self.upsample_layers = []
        for idx, upsample_factor in enumerate(upsample_factors):
            layer_out_channels = base_channels // (2**(idx + 1))
            layer_filter_size = upsample_factor * 2
            layer_stride = upsample_factor
            layer_output_padding = upsample_factor % 2
            self.upsample_layers += [
                tf.keras.layers.LeakyReLU(act_slope),
                tf.keras.layers.Conv2DTranspose(
                    filters=layer_out_channels,
                    kernel_size=(layer_filter_size, 1),
                    strides=(layer_stride, 1),
                    padding='same',
                    # output_padding=layer_output_padding,
                    use_bias=True,
                    name=f'{num_layers}'),
                ResidualStack(
                    channels=layer_out_channels,
                    num_res_blocks=num_res_blocks,
                    kernel_size=res_kernel,
                    name=f'layers.{num_layers + 1}'
                )
            ]
            num_layers += num_res_blocks - 1

        self.upsample_layers += [tf.keras.layers.LeakyReLU(act_slope)]

        # final layer
        self.final_layers = [
            ReflectionPad1d(base_padding),
            tf.keras.layers.Conv2D(filters=out_channels,
                                   kernel_size=(proj_kernel, 1),
                                   use_bias=True,
                                   name=f'layers.{num_layers + 1}'),
            tf.keras.layers.Activation("tanh")
        ]

        # keep a flat list of layers for sequential inference
        self.model_layers = self.initial_layer + self.upsample_layers + self.final_layers

    def call(self, c, training=False):
        """
        c : B x C x T
        """
        if training:
            raise NotImplementedError()
        return self.inference(c)

    def inference(self, c):
        c = tf.transpose(c, perm=[0, 2, 1])
        c = tf.expand_dims(c, 2)
        # reflect-pad the time axis for extra context at the edges
        c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
        o = c
        for layer in self.model_layers:
            o = layer(o)
        # B x T x 1 x C -> B x C x T
        o = tf.transpose(o, perm=[0, 3, 2, 1])
        return o[:, :, 0, :]
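A small smoke test under the defaults above: each input frame is upsampled by prod(upsample_factors) = 256 samples, and inference_padding reflects 2 extra frames onto each side beforehand, so the expected output length is (T + 4) * 256.

import tensorflow as tf

model = MelganGenerator()             # defaults: 80 mels, 256x upsampling
mel = tf.random.uniform((1, 80, 10))  # B x C x T
wav = model(mel, training=False)
print(wav.shape)                      # expected: (1, 1, (10 + 4) * 256) = (1, 1, 3584)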
@@ -0,0 +1,46 @@ TTS/vocoder/tf/models/multiband_melgan_generator.py
import tensorflow as tf

from TTS.vocoder.tf.models.melgan_generator import MelganGenerator
from TTS.vocoder.tf.layers.pqmf import PQMF


class MultibandMelganGenerator(MelganGenerator):
    def __init__(self,
                 in_channels=80,
                 out_channels=4,
                 proj_kernel=7,
                 base_channels=384,
                 upsample_factors=(2, 8, 2, 2),
                 res_kernel=3,
                 num_res_blocks=3):
        super(MultibandMelganGenerator,
              self).__init__(in_channels=in_channels,
                             out_channels=out_channels,
                             proj_kernel=proj_kernel,
                             base_channels=base_channels,
                             upsample_factors=upsample_factors,
                             res_kernel=res_kernel,
                             num_res_blocks=num_res_blocks)
        self.pqmf_layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0)

    def pqmf_analysis(self, x):
        return self.pqmf_layer.analysis(x)

    def pqmf_synthesis(self, x):
        return self.pqmf_layer.synthesis(x)

    def inference(self, c):
        c = tf.transpose(c, perm=[0, 2, 1])
        c = tf.expand_dims(c, 2)
        c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
        o = c
        for layer in self.model_layers:
            o = layer(o)
        o = tf.transpose(o, perm=[0, 3, 2, 1])
        # merge the predicted subbands back into a single waveform
        o = self.pqmf_layer.synthesis(o[:, :, 0, :])
        return o
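The same shape reasoning, hedged, for the multiband variant: the stack upsamples by 2 * 8 * 2 * 2 = 64 into 4 subbands, and PQMF synthesis restores the remaining factor of 4.

import tensorflow as tf

model = MultibandMelganGenerator()    # defaults: 4 subbands, 64x stack
mel = tf.random.uniform((1, 80, 10))
wav = model(mel, training=False)
print(wav.shape)                      # expected: (1, 1, (10 + 4) * 64 * 4) = (1, 1, 3584)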
@@ -0,0 +1,48 @@ TTS/vocoder/tf/utils/convert_torch_to_tf_utils.py
import numpy as np
import tensorflow as tf


def compare_torch_tf(torch_tensor, tf_tensor):
    """ Compute the average absolute difference between torch and tf tensors """
    return abs(torch_tensor.detach().numpy() - tf_tensor.numpy()).mean()


def convert_tf_name(tf_name):
    """ Convert certain patterns in TF layer names to Torch patterns """
    tf_name_tmp = tf_name
    tf_name_tmp = tf_name_tmp.replace(':0', '')
    tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_1/recurrent_kernel', '/weight_hh_l0')
    tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_2/kernel', '/weight_ih_l1')
    tf_name_tmp = tf_name_tmp.replace('/recurrent_kernel', '/weight_hh')
    tf_name_tmp = tf_name_tmp.replace('/kernel', '/weight')
    tf_name_tmp = tf_name_tmp.replace('/gamma', '/weight')
    tf_name_tmp = tf_name_tmp.replace('/beta', '/bias')
    tf_name_tmp = tf_name_tmp.replace('/', '.')
    return tf_name_tmp


def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict):
    """ Transfer weights from a torch state_dict to TF variables """
    print(" > Passing weights from Torch to TF ...")
    for tf_var in tf_vars:
        torch_var_name = var_map_dict[tf_var.name]
        print(f' | > {tf_var.name} <-- {torch_var_name}')
        if 'kernel' in tf_var.name:
            # torch conv kernels are [out, in, width]; TF Conv2D expects
            # [width, 1, in, out]
            torch_weight = state_dict[torch_var_name]
            numpy_weight = torch_weight.permute([2, 1, 0]).numpy()[:, None, :, :]
        if 'bias' in tf_var.name:
            torch_weight = state_dict[torch_var_name]
            numpy_weight = torch_weight
        assert np.all(
            tf_var.shape == numpy_weight.shape
        ), f" [!] weight shapes do not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}"
        tf.keras.backend.set_value(tf_var, numpy_weight)
    return tf_vars


def load_tf_vars(model_tf, tf_vars):
    for tf_var in tf_vars:
        model_tf.get_layer(tf_var.name).set_weights(tf_var)
    return model_tf
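For illustration, a hypothetical Keras variable name run through convert_tf_name; the fuzzy matcher in the conversion script compares such strings against the Torch state_dict keys:

print(convert_tf_name('melgan_generator/blocks.0.2/kernel:0'))
# -> 'melgan_generator.blocks.0.2.weight'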
@@ -0,0 +1,37 @@ TTS/vocoder/tf/utils/generic_utils.py
import re
import importlib

import numpy as np
from matplotlib import pyplot as plt


def to_camel(text):
    text = text.capitalize()
    return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)


def setup_generator(c):
    print(" > Generator Model: {}".format(c.generator_model))
    MyModel = importlib.import_module('TTS.vocoder.tf.models.' +
                                      c.generator_model.lower())
    MyModel = getattr(MyModel, to_camel(c.generator_model))
    if c.generator_model == 'melgan_generator':
        model = MyModel(
            in_channels=c.audio['num_mels'],
            out_channels=1,
            proj_kernel=7,
            base_channels=512,
            upsample_factors=c.generator_model_params['upsample_factors'],
            res_kernel=3,
            num_res_blocks=c.generator_model_params['num_res_blocks'])
    if c.generator_model == 'melgan_fb_generator':
        pass
    if c.generator_model == 'multiband_melgan_generator':
        model = MyModel(
            in_channels=c.audio['num_mels'],
            out_channels=4,
            proj_kernel=7,
            base_channels=384,
            upsample_factors=c.generator_model_params['upsample_factors'],
            res_kernel=3,
            num_res_blocks=c.generator_model_params['num_res_blocks'])
    return model
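to_camel maps the snake_case model name from the config to the class name exported by the corresponding module, for example:

print(to_camel('multiband_melgan_generator'))  # -> 'MultibandMelganGenerator'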
@@ -0,0 +1,27 @@ TTS/vocoder/tf/utils/io.py
import datetime
import pickle

import tensorflow as tf


def save_checkpoint(model, current_step, epoch, output_path, **kwargs):
    """ Save TF Vocoder model """
    state = {
        'model': model.weights,
        'step': current_step,
        'epoch': epoch,
        'date': datetime.date.today().strftime("%B %d, %Y"),
    }
    state.update(kwargs)
    pickle.dump(state, open(output_path, 'wb'))


def load_checkpoint(model, checkpoint_path):
    """ Load TF Vocoder model """
    checkpoint = pickle.load(open(checkpoint_path, 'rb'))
    chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']}
    tf_vars = model.weights
    for tf_var in tf_vars:
        layer_name = tf_var.name
        chkp_var_value = chkp_var_dict[layer_name]
        tf.keras.backend.set_value(tf_var, chkp_var_value)
    return model
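A hedged round-trip sketch for these helpers, with a hypothetical path: the model has to be called once so its variables exist before saving or loading.

import tensorflow as tf

from TTS.vocoder.tf.models.melgan_generator import MelganGenerator

model = MelganGenerator()
model(tf.random.uniform((1, 80, 10)), training=False)  # build variables
save_checkpoint(model, current_step=10000, epoch=5,
                output_path='/tmp/tf_vocoder.pkl')      # hypothetical path
model = load_checkpoint(model, '/tmp/tf_vocoder.pkl')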