linter updates

erogol 2020-05-18 18:46:13 +02:00
parent 496ff68dec
commit f75b0a6439
17 changed files with 177 additions and 163 deletions

View File

@@ -9,7 +9,6 @@ import torch.distributed as dist
 from torch.utils.data.sampler import Sampler
 from torch.autograd import Variable
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from TTS.utils.io import load_config
 from TTS.utils.generic_utils import create_experiment_folder

View File

@@ -11,9 +11,9 @@ class ConvBNBlock(nn.Module):
         assert (kernel_size - 1) % 2 == 0
         padding = (kernel_size - 1) // 2
         self.convolution1d = nn.Conv1d(in_channels,
                                        out_channels,
                                        kernel_size,
                                        padding=padding)
         self.batch_normalization = nn.BatchNorm1d(out_channels, momentum=0.1, eps=1e-5)
         self.dropout = nn.Dropout(p=0.5)
         if activation == 'relu':

View File

@@ -171,12 +171,12 @@ class Synthesizer(object):
         speaker_id = id_to_torch(speaker_id)
         if speaker_id is not None and self.use_cuda:
             speaker_id = speaker_id.cuda()
         for sen in sens:
             # preprocess the given text
             inputs = text_to_seqvec(sen, self.tts_config, self.use_cuda)
             # synthesize voice
-            decoder_output, postnet_output, alignments, _ = run_model(
+            decoder_output, postnet_output, alignments, _ = run_model_torch(
                 self.tts_model, inputs, self.tts_config, False, speaker_id, None)
             # convert outputs to numpy
             postnet_output, decoder_output, _ = parse_outputs(

View File

@@ -25,7 +25,7 @@ def tts(model,
         figures=False):
     t_1 = time.time()
     use_vocoder_model = vocoder_model is not None
-    waveform, alignment, _, postnet_output, stop_tokens = synthesis(
+    waveform, alignment, _, postnet_output, stop_tokens, _ = synthesis(
         model, text, C, use_cuda, ap, speaker_id, style_wav=False,
         truncated=False, enable_eos_bos_chars=C.enable_eos_bos_chars,
         use_griffin_lim=(not use_vocoder_model), do_trim_silence=True)

View File

@@ -1,14 +1,10 @@
 import os
-import copy
 import torch
 import unittest
 import numpy as np
 import tensorflow as tf
-from torch import optim
-from torch import nn
 from TTS.utils.io import load_config
-from TTS.layers.losses import MSELossMasked
 from TTS.tf.models.tacotron2 import Tacotron2

 #pylint: disable=unused-variable
@@ -22,36 +18,44 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
 class TacotronTFTrainTest(unittest.TestCase):
-    def test_train_step(self):
-        ''' test forward pass '''
-        input = torch.randint(0, 24, (8, 128)).long().to(device)
-        input_lengths = torch.randint(100, 128, (8, )).long().to(device)
-        input_lengths = torch.sort(input_lengths, descending=True)[0]
+    @staticmethod
+    def generate_dummy_inputs():
+        chars_seq = torch.randint(0, 24, (8, 128)).long().to(device)
+        chars_seq_lengths = torch.randint(100, 128, (8, )).long().to(device)
+        chars_seq_lengths = torch.sort(chars_seq_lengths, descending=True)[0]
         mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
         mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
         mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
         stop_targets = torch.zeros(8, 30, 1).float().to(device)
         speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
-        input = tf.convert_to_tensor(input.cpu().numpy())
-        input_lengths = tf.convert_to_tensor(input_lengths.cpu().numpy())
+        chars_seq = tf.convert_to_tensor(chars_seq.cpu().numpy())
+        chars_seq_lengths = tf.convert_to_tensor(chars_seq_lengths.cpu().numpy())
         mel_spec = tf.convert_to_tensor(mel_spec.cpu().numpy())
+        return chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths,\
+            stop_targets, speaker_ids
+
+    def test_train_step(self):
+        ''' test forward pass '''
+        chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths,\
+            stop_targets, speaker_ids = self.generate_dummy_inputs()
         for idx in mel_lengths:
             stop_targets[:, int(idx.item()):, 0] = 1.0
-        stop_targets = stop_targets.view(input.shape[0],
+        stop_targets = stop_targets.view(chars_seq.shape[0],
                                          stop_targets.size(1) // c.r, -1)
         stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
         model = Tacotron2(num_chars=24, r=c.r, num_speakers=5)
         # training pass
-        output = model(input, input_lengths, mel_spec, training=True)
+        output = model(chars_seq, chars_seq_lengths, mel_spec, training=True)
         # check model output shapes
         assert np.all(output[0].shape == mel_spec.shape)
         assert np.all(output[1].shape == mel_spec.shape)
-        assert output[2].shape[2] == input.shape[1]
+        assert output[2].shape[2] == chars_seq.shape[1]
         assert output[2].shape[1] == (mel_spec.shape[1] // model.decoder.r)
         assert output[3].shape[1] == (mel_spec.shape[1] // model.decoder.r)

View File

@@ -10,27 +10,23 @@ import torch
 import tensorflow as tf
 from fuzzywuzzy import fuzz

-from TTS.utils.text.symbols import make_symbols, phonemes, symbols
-from TTS.utils.generic_utils import setup_model, count_parameters
+from TTS.utils.text.symbols import phonemes, symbols
+from TTS.utils.generic_utils import setup_model
 from TTS.utils.io import load_config
 from TTS_tf.models.tacotron2 import Tacotron2
 from TTS_tf.utils.convert_torch_to_tf_utils import compare_torch_tf, tf_create_dummy_inputs, transfer_weights_torch_to_tf, convert_tf_name
 from TTS_tf.utils.generic_utils import save_checkpoint

 parser = argparse.ArgumentParser()
-parser.add_argument(
-    '--torch_model_path',
-    type=str,
-    help='Path to target torch model to be converted to TF.')
-parser.add_argument(
-    '--config_path',
-    type=str,
-    help='Path to config file of torch model.')
-parser.add_argument(
-    '--output_path',
-    type=str,
-    help='path to save TF model weights.')
+parser.add_argument('--torch_model_path',
+                    type=str,
+                    help='Path to target torch model to be converted to TF.')
+parser.add_argument('--config_path',
+                    type=str,
+                    help='Path to config file of torch model.')
+parser.add_argument('--output_path',
+                    type=str,
+                    help='path to save TF model weights.')
 args = parser.parse_args()

 # load model config
@@ -41,7 +37,8 @@ num_speakers = 0
 # init torch model
 num_chars = len(phonemes) if c.use_phonemes else len(symbols)
 model = setup_model(num_chars, num_speakers, c)
-checkpoint = torch.load(args.torch_model_path, map_location=torch.device('cpu'))
+checkpoint = torch.load(args.torch_model_path,
+                        map_location=torch.device('cpu'))
 state_dict = checkpoint['model']
 model.load_state_dict(state_dict)
@@ -69,18 +66,24 @@ model_tf = Tacotron2(num_chars=num_chars,
 common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
 var_map = [
     ('tacotron2/embedding/embeddings:0', 'embedding.weight'),
-    ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/kernel:0', 'encoder.lstm.weight_ih_l0'),
-    ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0', 'encoder.lstm.weight_hh_l0'),
-    ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/kernel:0', 'encoder.lstm.weight_ih_l0_reverse'),
-    ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0', 'encoder.lstm.weight_hh_l0_reverse'),
-    ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/bias:0', ('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')),
-    ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/bias:0', ('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')),
+    ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/kernel:0',
+     'encoder.lstm.weight_ih_l0'),
+    ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0',
+     'encoder.lstm.weight_hh_l0'),
+    ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/kernel:0',
+     'encoder.lstm.weight_ih_l0_reverse'),
+    ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0',
+     'encoder.lstm.weight_hh_l0_reverse'),
+    ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/bias:0',
+     ('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')),
+    ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/bias:0',
+     ('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')),
     ('attention/v/kernel:0', 'decoder.attention.v.linear_layer.weight'),
-    ('decoder/linear_projection/kernel:0', 'decoder.linear_projection.linear_layer.weight'),
+    ('decoder/linear_projection/kernel:0',
+     'decoder.linear_projection.linear_layer.weight'),
     ('decoder/stopnet/kernel:0', 'decoder.stopnet.1.linear_layer.weight')
 ]

 # %%
 # get tf_model graph
 input_ids, input_lengths, mel_outputs, mel_lengths = tf_create_dummy_inputs()
@@ -95,15 +98,17 @@ tf_var_names = [we.name for we in model_tf.weights]
 for tf_name in tf_var_names:
     # skip re-mapped layer names
     if tf_name in [name[0] for name in var_map]:
         continue
     tf_name_edited = convert_tf_name(tf_name)
-    ratios = [fuzz.ratio(torch_name, tf_name_edited) for torch_name in torch_var_names]
+    ratios = [
+        fuzz.ratio(torch_name, tf_name_edited)
+        for torch_name in torch_var_names
+    ]
     max_idx = np.argmax(ratios)
     matching_name = torch_var_names[max_idx]
     del torch_var_names[max_idx]
     var_map.append((tf_name, matching_name))

 # %%
 # print variable match
 from pprint import pprint
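
Aside: the loop above greedily pairs each remaining TF variable with the closest PyTorch parameter name by string similarity. A minimal sketch of the same idea with hypothetical names (assumes fuzzywuzzy is installed; illustrative only):

import numpy as np
from fuzzywuzzy import fuzz

# hypothetical unmatched variable names on each side
tf_names = ['encoder/conv1d/kernel:0', 'decoder/linear/kernel:0']
torch_names = ['decoder.linear.weight', 'encoder.conv1d.weight']

var_map = []
for tf_name in tf_names:
    # score this TF name against every torch name still in the pool
    ratios = [fuzz.ratio(torch_name, tf_name) for torch_name in torch_names]
    best = int(np.argmax(ratios))
    # greedily take the best match and drop it so it cannot be reused
    var_map.append((tf_name, torch_names.pop(best)))
print(var_map)  # each TF variable paired with its closest torch name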
@@ -121,20 +126,25 @@ input_ids = torch.randint(0, 24, (1, 128)).long()
 o_t = model.embedding(input_ids)
 o_tf = model_tf.embedding(input_ids.detach().numpy())
-assert abs(o_t.detach().numpy() - o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() - o_tf.numpy()).sum()
+assert abs(o_t.detach().numpy() -
+           o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() -
+                                           o_tf.numpy()).sum()

 # compare encoder outputs
-oo_en = model.encoder.inference(o_t.transpose(1,2))
+oo_en = model.encoder.inference(o_t.transpose(1, 2))
 ooo_en = model_tf.encoder(o_t.detach().numpy(), training=False)
 assert compare_torch_tf(oo_en, ooo_en) < 1e-5

+#pylint: disable=redefined-builtin
 # compare decoder.attention_rnn
 inp = torch.rand([1, 768])
 inp_tf = inp.numpy()
-model.decoder._init_states(oo_en, mask=None)
+model.decoder._init_states(oo_en, mask=None) #pylint: disable=protected-access
 output, cell_state = model.decoder.attention_rnn(inp)
-states = model_tf.decoder.build_decoder_initial_states(1,512,128)
-output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf, states[2], training=False)
+states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
+output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf,
+                                                         states[2],
+                                                         training=False)
 assert compare_torch_tf(output, output_tf).mean() < 1e-5

 # compare decoder.attention
@@ -145,7 +155,8 @@ inputs_tf = inputs.numpy()
 model.decoder.attention.init_states(inputs)
 processes_inputs = model.decoder.attention.preprocess_inputs(inputs)
-loc_attn, proc_query = model.decoder.attention.get_location_attention(query, processes_inputs)
+loc_attn, proc_query = model.decoder.attention.get_location_attention(
+    query, processes_inputs)
 context = model.decoder.attention(query, inputs, processes_inputs, None)

 model_tf.decoder.attention.process_values(tf.convert_to_tensor(inputs_tf))
@@ -159,10 +170,13 @@ assert compare_torch_tf(context, context_tf) < 1e-5
 # compare decoder.decoder_rnn
 input = torch.rand([1, 1536])
 input_tf = input.numpy()
-model.decoder._init_states(oo_en, mask=None)
-output, cell_state = model.decoder.decoder_rnn(input, [model.decoder.decoder_hidden, model.decoder.decoder_cell])
-states = model_tf.decoder.build_decoder_initial_states(1,512,128)
-output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf, states[3], training=False)
+model.decoder._init_states(oo_en, mask=None) #pylint: disable=protected-access
+output, cell_state = model.decoder.decoder_rnn(
+    input, [model.decoder.decoder_hidden, model.decoder.decoder_cell])
+states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
+output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf,
+                                                       states[3],
+                                                       training=False)
 assert abs(input - input_tf).mean() < 1e-5
 assert compare_torch_tf(output, output_tf).mean() < 1e-5
@@ -177,15 +191,16 @@ assert compare_torch_tf(output, output_tf) < 1e-5
 model.decoder.max_decoder_steps = 100
 model_tf.decoder.set_max_decoder_steps(100)
 output, align, stop = model.decoder.inference(oo_en)
-states = model_tf.decoder.build_decoder_initial_states(1,512,128)
+states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
 output_tf, align_tf, stop_tf = model_tf.decoder(ooo_en, states, training=False)
-assert compare_torch_tf(output.transpose(1,2), output_tf) < 1e-4
+assert compare_torch_tf(output.transpose(1, 2), output_tf) < 1e-4

 # compare the whole model output
 outputs_torch = model.inference(input_ids)
 outputs_tf = model_tf(tf.convert_to_tensor(input_ids.numpy()))
-print(abs(outputs_torch[0].numpy()[:, 0] - outputs_tf[0].numpy()[:, 0]).mean() )
-assert compare_torch_tf(outputs_torch[2][:, 50, :], outputs_tf[2][:, 50, :]) < 1e-5
+print(abs(outputs_torch[0].numpy()[:, 0] - outputs_tf[0].numpy()[:, 0]).mean())
+assert compare_torch_tf(outputs_torch[2][:, 50, :],
+                        outputs_tf[2][:, 50, :]) < 1e-5
 assert compare_torch_tf(outputs_torch[0], outputs_tf[0]) < 1e-4

 # %%
@@ -193,4 +208,3 @@ assert compare_torch_tf(outputs_torch[0], outputs_tf[0]) < 1e-4
 save_checkpoint(model_tf, None, checkpoint['step'], checkpoint['epoch'],
                 checkpoint['r'], args.output_path)
 print(' > Model conversion is successfully completed :).')
-

View File

@@ -3,8 +3,6 @@ from tensorflow import keras
 from tensorflow.python.ops import math_ops
 # from tensorflow_addons.seq2seq import BahdanauAttention

-from TTS.tf.utils.tf_utils import shape_list
-
 class Linear(keras.layers.Layer):
     def __init__(self, units, use_bias, **kwargs):
@@ -12,7 +10,7 @@ class Linear(keras.layers.Layer):
         self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer')
         self.activation = keras.layers.ReLU()

-    def call(self, x, training=None):
+    def call(self, x):
         """
         shapes:
             x: B x T x C
@@ -77,9 +75,9 @@ def _sigmoid_norm(score):
 class Attention(keras.layers.Layer):
-    """TODO: implement forward_attention"""
-    """TODO: location sensitive attention"""
-    """TODO: implement attention windowing """
+    """TODO: implement forward_attention
+    TODO: location sensitive attention
+    TODO: implement attention windowing """
     def __init__(self, attn_dim, use_loc_attn, loc_attn_n_filters,
                  loc_attn_kernel_size, use_windowing, norm, use_forward_attn,
                  use_trans_agent, use_forward_attn_mask, **kwargs):
@@ -120,6 +118,7 @@ class Attention(keras.layers.Layer):
     def process_values(self, values):
         """ cache values for decoder iterations """
+        #pylint: disable=attribute-defined-outside-init
         self.processed_values = self.inputs_layer(values)
         self.values = values
@@ -127,8 +126,7 @@ class Attention(keras.layers.Layer):
         """ compute location attention, query layer and
         unnorm. attention weights"""
         attention_cum, attention_old = states
-        attn_cat = tf.stack([attention_old, attention_cum],
-                            axis=2)
+        attn_cat = tf.stack([attention_old, attention_cum], axis=2)

         processed_query = self.query_layer(tf.expand_dims(query, 1))
         processed_attn = self.location_dense(self.location_conv1d(attn_cat))
@@ -145,7 +143,7 @@ class Attention(keras.layers.Layer):
         score = tf.squeeze(score, axis=2)
         return score, processed_query

-    def apply_score_masking(self, score, mask):
+    def apply_score_masking(self, score, mask): #pylint: disable=no-self-use
         """ ignore sequence paddings """
         padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
         # Bias so padding positions do not contribute to attention distribution.
@@ -158,13 +156,13 @@ class Attention(keras.layers.Layer):
             query: B x D
         """
         if self.use_loc_attn:
-            score, processed_query = self.get_loc_attn(query, states)
+            score, _ = self.get_loc_attn(query, states)
         else:
-            score, processed_query = self.get_attn(query)
+            score, _ = self.get_attn(query)

         # TODO: masking
         # if mask is not None:
         #     self.apply_score_masking(score, mask)
         # attn_weights shape == (batch_size, max_length, 1)
         attn_weights = self.norm_func(score)

View File

@@ -55,6 +55,7 @@ class Encoder(keras.layers.Layer):

 class Decoder(keras.layers.Layer):
+    #pylint: disable=unused-argument
     def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm, prenet_type,
                  prenet_dropout, use_forward_attn, use_trans_agent, use_forward_attn_mask,
                  use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, **kwargs):
@@ -135,7 +136,7 @@ class Decoder(keras.layers.Layer):
         return output_frame, stopnet_output, states, attention

     def decode(self, memory, states, frames, memory_seq_length=None):
-        B, T, D = shape_list(memory)
+        B, _, _ = shape_list(memory)
         num_iter = shape_list(frames)[1] // self.r
         # init states
         frame_zero = tf.expand_dims(states[0], 1)
@@ -159,25 +160,25 @@ class Decoder(keras.layers.Layer):
             return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions
         _, memory, _, states, outputs, stop_tokens, attentions = \
             tf.while_loop(lambda *arg: True,
                           _body,
-                          loop_vars=(step_count, memory, prenet_output, states, outputs,
-                                     stop_tokens, attentions),
+                          loop_vars=(step_count, memory, prenet_output,
+                                     states, outputs, stop_tokens, attentions),
                           parallel_iterations=32,
                           swap_memory=True,
                           maximum_iterations=num_iter)

         outputs = outputs.stack()
         attentions = attentions.stack()
         stop_tokens = stop_tokens.stack()
         outputs = tf.transpose(outputs, [1, 0, 2])
-        attentions = tf.transpose(attentions, [1, 0 ,2])
+        attentions = tf.transpose(attentions, [1, 0, 2])
         stop_tokens = tf.transpose(stop_tokens, [1, 0, 2])
         stop_tokens = tf.squeeze(stop_tokens, axis=2)
         outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
         return outputs, stop_tokens, attentions

     def decode_inference(self, memory, states):
-        B, T, D = shape_list(memory)
+        B, _, _ = shape_list(memory)
         # init states
         outputs = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
         attentions = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
@@ -207,12 +208,12 @@ class Decoder(keras.layers.Layer):
         cond = lambda step, m, s, o, st, a, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
         _, memory, states, outputs, stop_tokens, attentions, stop_flag = \
             tf.while_loop(cond,
                           _body,
                           loop_vars=(step_count, memory, states, outputs,
                                      stop_tokens, attentions, stop_flag),
                           parallel_iterations=32,
                           swap_memory=True,
                           maximum_iterations=self.max_decoder_steps)

         outputs = outputs.stack()
         attentions = attentions.stack()
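
Aside: both decode loops above follow the usual tf.while_loop plus TensorArray idiom: write one frame per step, then stack the array into a single tensor. A self-contained toy sketch of that pattern (TF 2.x, not the decoder itself):

import tensorflow as tf

def collect_squares(n):
    # one growable array per output stream, like the decoder's frame buffer
    outputs = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)

    def _body(step, outputs):
        # write one value per iteration, mirroring one decoder step
        outputs = outputs.write(step, tf.cast(step * step, tf.float32))
        return step + 1, outputs

    _, outputs = tf.while_loop(lambda step, outputs: step < n,
                               _body,
                               loop_vars=(0, outputs),
                               parallel_iterations=32,
                               swap_memory=True)
    return outputs.stack()  # stack the per-step values into one tensor

print(collect_squares(tf.constant(5)))  # [ 0.  1.  4.  9. 16.]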

View File

@@ -1,10 +1,10 @@
-import tensorflow as tf
 from tensorflow import keras

 from TTS.tf.layers.tacotron2 import Encoder, Decoder, Postnet
 from TTS.tf.utils.tf_utils import shape_list

+#pylint: disable=too-many-ancestors
 class Tacotron2(keras.models.Model):
     def __init__(self,
                  num_chars,
@@ -35,16 +35,28 @@ class Tacotron2(keras.models.Model):
         self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding')
         self.encoder = Encoder(512, name='encoder')
         # TODO: most of the decoder args have no use at the momment
-        self.decoder = Decoder(decoder_output_dim, r, attn_type=attn_type, use_attn_win=attn_win, attn_norm=attn_norm, prenet_type=prenet_type,
-                               prenet_dropout=prenet_dropout, use_forward_attn=forward_attn, use_trans_agent=trans_agent, use_forward_attn_mask=forward_attn_mask,
-                               use_location_attn=location_attn, attn_K=attn_K, separate_stopnet=separate_stopnet, speaker_emb_dim=self.speaker_embed_dim)
+        self.decoder = Decoder(decoder_output_dim,
+                               r,
+                               attn_type=attn_type,
+                               use_attn_win=attn_win,
+                               attn_norm=attn_norm,
+                               prenet_type=prenet_type,
+                               prenet_dropout=prenet_dropout,
+                               use_forward_attn=forward_attn,
+                               use_trans_agent=trans_agent,
+                               use_forward_attn_mask=forward_attn_mask,
+                               use_location_attn=location_attn,
+                               attn_K=attn_K,
+                               separate_stopnet=separate_stopnet,
+                               speaker_emb_dim=self.speaker_embed_dim)
         self.postnet = Postnet(postnet_output_dim, 5, name='postnet')

     def call(self, characters, text_lengths=None, frames=None, training=None):
-        if training == True:
+        if training:
             return self.training(characters, text_lengths, frames)
-        else:
-            return self.inference(characters)
+        if not training:
+            return self.inference(characters)
+        raise RuntimeError(' [!] Set model training mode True or False')

     def training(self, characters, text_lengths, frames):
         B, T = shape_list(characters)
@@ -67,6 +79,3 @@ class Tacotron2(keras.models.Model):
         print(output_frames.shape)
         return decoder_frames, output_frames, attentions, stop_tokens
-
-
-

View File

@@ -1,8 +1,5 @@
 import numpy as np
-import torch
-import re
 import tensorflow as tf
-import tensorflow.keras.backend as K


 def tf_create_dummy_inputs():
@@ -17,7 +14,7 @@ def tf_create_dummy_inputs():
     input_lengths[-1] = max_input_length
     input_lengths = tf.convert_to_tensor(input_lengths, dtype=tf.int32)
     mel_outputs = tf.random.uniform(shape=[batch_size, max_mel_length + pad, 80])
     mel_lengths = np.random.randint(0, high=max_mel_length+1 + pad, size=[batch_size])
     mel_lengths[-1] = max_mel_length
     mel_lengths = tf.convert_to_tensor(mel_lengths, dtype=tf.int32)
     return input_ids, input_lengths, mel_outputs, mel_lengths
@@ -49,7 +46,7 @@ def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict):
         torch_var_name = var_map_dict[tf_var.name]
         print(f' | > {tf_var.name} <-- {torch_var_name}')
         # if tuple, it is a bias variable
-        if type(torch_var_name) is not tuple:
+        if not isinstance(torch_var_name, tuple):
             torch_layer_name = '.'.join(torch_var_name.split('.')[-2:])
             torch_weight = state_dict[torch_var_name]
         if 'convolution1d/kernel' in tf_var.name or 'conv1d/kernel' in tf_var.name:

View File

@@ -1,14 +1,8 @@
 import os
-import re
-import glob
-import shutil
 import datetime
-import json
-import subprocess
 import importlib
 import pickle
 import numpy as np
-from collections import OrderedDict, Counter

 import tensorflow as tf
@@ -29,7 +23,7 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):

 def load_checkpoint(model, checkpoint_path):
     checkpoint = pickle.load(open(checkpoint_path, 'rb'))
-    chkp_var_dict = dict([(var.name, var.numpy()) for var in checkpoint['model']])
+    chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']}
     tf_vars = model.weights
     for tf_var in tf_vars:
         layer_name = tf_var.name
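
Aside: the dict([...]) to {...} rewrite above is the standard fix for pylint's consider-using-dict-comprehension; both build the same mapping, the comprehension just skips the intermediate list of tuples:

pairs = [('a', 1), ('b', 2)]
# identical results; the comprehension avoids materializing a temporary list
assert dict([(key, val) for key, val in pairs]) == {key: val for key, val in pairs}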
@@ -64,7 +58,7 @@ def check_gradient(x, grad_clip):
 def count_parameters(model, c):
     try:
         return model.count_params()
-    except:
+    except RuntimeError:
         input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype('int32'))
         input_lengths = np.random.randint(100, 129, (8, ))
         input_lengths[-1] = 128
@@ -74,7 +68,7 @@ def count_parameters(model, c):
         mel_spec = tf.convert_to_tensor(mel_spec)
         speaker_ids = np.random.randint(
             0, 5, (8, )) if c.use_speaker_embedding else None
-        _ = model(input_dummy, input_lengths, mel_spec)
+        _ = model(input_dummy, input_lengths, mel_spec, speaker_ids=speaker_ids)
         return model.count_params()
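
Aside: count_params() fails on a subclassed Keras model until a forward pass has built its variables, which is why the fallback above runs a dummy batch first. A sketch of the same idea with a toy model (the exact exception type varies across TF versions, hence the broad catch here):

import numpy as np
import tensorflow as tf
from tensorflow import keras

class ToyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(4)

    def call(self, x):
        return self.dense(x)

model = ToyModel()
try:
    n_params = model.count_params()
except (ValueError, RuntimeError):
    # run one dummy batch so every layer builds its weights first
    _ = model(tf.convert_to_tensor(np.random.rand(2, 8).astype('float32')))
    n_params = model.count_params()
print(n_params)  # 8 * 4 kernel weights + 4 biases = 36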
@@ -83,23 +77,23 @@ def setup_model(num_chars, num_speakers, c):
     MyModel = importlib.import_module('TTS.tf.models.' + c.model.lower())
     MyModel = getattr(MyModel, c.model)
     if c.model.lower() in "tacotron":
-        raise NotImplemented(' [!] Tacotron model is not ready.')
-    elif c.model.lower() == "tacotron2":
-        model = MyModel(num_chars=num_chars,
-                        num_speakers=num_speakers,
-                        r=c.r,
-                        postnet_output_dim=c.audio['num_mels'],
-                        decoder_output_dim=c.audio['num_mels'],
-                        attn_type=c.attention_type,
-                        attn_win=c.windowing,
-                        attn_norm=c.attention_norm,
-                        prenet_type=c.prenet_type,
-                        prenet_dropout=c.prenet_dropout,
-                        forward_attn=c.use_forward_attn,
-                        trans_agent=c.transition_agent,
-                        forward_attn_mask=c.forward_attn_mask,
-                        location_attn=c.location_attn,
-                        attn_K=c.attention_heads,
-                        separate_stopnet=c.separate_stopnet,
-                        bidirectional_decoder=c.bidirectional_decoder)
+        raise NotImplementedError(' [!] Tacotron model is not ready.')
+    # tacotron2
+    model = MyModel(num_chars=num_chars,
+                    num_speakers=num_speakers,
+                    r=c.r,
+                    postnet_output_dim=c.audio['num_mels'],
+                    decoder_output_dim=c.audio['num_mels'],
+                    attn_type=c.attention_type,
+                    attn_win=c.windowing,
+                    attn_norm=c.attention_norm,
+                    prenet_type=c.prenet_type,
+                    prenet_dropout=c.prenet_dropout,
+                    forward_attn=c.use_forward_attn,
+                    trans_agent=c.transition_agent,
+                    forward_attn_mask=c.forward_attn_mask,
+                    location_attn=c.location_attn,
+                    attn_K=c.attention_heads,
+                    separate_stopnet=c.separate_stopnet,
+                    bidirectional_decoder=c.bidirectional_decoder)
     return model

View File

@@ -190,7 +190,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
         # backward pass
         loss_dict['loss'].backward()
         optimizer, current_lr = adam_weight_decay(optimizer)
-        grad_norm, grad_flag = check_update(model, c.grad_clip, ignore_stopnet=True)
+        grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
         optimizer.step()

         # compute alignment error (the lower the better )
@@ -232,8 +232,7 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
             loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus)
             loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus)
             loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
-            loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data,
-                                                      num_gpus) if c.stopnet else loss_dict['stopnet_loss']
+            loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus) if c.stopnet else loss_dict['stopnet_loss']

         if args.rank == 0:
             # Plot Training Iter Stats
@@ -308,8 +307,6 @@ def train(model, criterion, optimizer, optimizer_st, scheduler,
 @torch.no_grad()
 def evaluate(model, criterion, ap, global_step, epoch):
     data_loader = setup_loader(ap, model.decoder.r, is_val=True)
-    if c.use_speaker_embedding:
-        speaker_mapping = load_speaker_mapping(OUT_PATH)
     model.eval()
     epoch_time = 0
     eval_values_dict = {

View File

@@ -6,6 +6,7 @@ import datetime
 import subprocess
 import importlib
 import numpy as np
+from collections import Counter


 def get_git_branch():
@@ -40,10 +41,10 @@ def get_commit_hash():
 def create_experiment_folder(root_path, model_name, debug):
     """ Create a folder with the current date and time """
     date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
-    # if debug:
-    #     commit_hash = 'debug'
-    # else:
-    commit_hash = get_commit_hash()
+    if debug:
+        commit_hash = 'debug'
+    else:
+        commit_hash = get_commit_hash()
     output_folder = os.path.join(
         root_path, model_name + '-' + date_str + '-' + commit_hash)
     os.makedirs(output_folder, exist_ok=True)
@@ -87,8 +88,7 @@ def split_dataset(items):
             items_eval.append(items[item_idx])
             del items[item_idx]
         return items_eval, items
-    else:
-        return items[:eval_split_size], items[eval_split_size:]
+    return items[:eval_split_size], items[eval_split_size:]


 # from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1

View File

@@ -26,7 +26,7 @@ def copy_config_file(config_file, out_path, new_fields):
     config_lines = open(config_file, "r").readlines()
     # add extra information fields
     for key, value in new_fields.items():
-        if type(value) == str:
+        if isinstance(value, str):
             new_line = '"{}":"{}",\n'.format(key, value)
         else:
             new_line = '"{}":{},\n'.format(key, value)
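
Aside: isinstance() is the standard fix for pylint's unidiomatic-typecheck warning addressed above; unlike an exact type() comparison it also accepts subclasses. A quick illustration with a hypothetical str subclass:

class Flag(str):
    """A str subclass, e.g. a value produced by a config loader."""

value = Flag('true')
print(type(value) == str)      # False: the exact type check misses the subclass
print(isinstance(value, str))  # True: subclasses of str still qualify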
@@ -37,7 +37,7 @@ def copy_config_file(config_file, out_path, new_fields):

 def load_checkpoint(model, checkpoint_path, use_cuda=False):
     state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
     model.load_state_dict(state['model'])
     if use_cuda:
         model.cuda()
@@ -55,7 +55,7 @@ def save_model(model, optimizer, current_step, epoch, r, output_path, **kwargs):
         'step': current_step,
         'epoch': epoch,
         'date': datetime.date.today().strftime("%B %d, %Y"),
-        'r': model.decoder.r
+        'r': r
     }
     state.update(kwargs)
     torch.save(state, output_path)
@@ -65,7 +65,7 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):
     file_name = 'checkpoint_{}.pth.tar'.format(current_step)
     checkpoint_path = os.path.join(output_folder, file_name)
     print(" > CHECKPOINT : {}".format(checkpoint_path))
-    save_model(model, optimizer, current_step, epoch ,r, checkpoint_path, **kwargs)
+    save_model(model, optimizer, current_step, epoch, r, checkpoint_path, **kwargs)


 def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, **kwargs):
@@ -73,6 +73,6 @@ def save_best_model(target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, **kwargs):
         file_name = 'best_model.pth.tar'
         checkpoint_path = os.path.join(output_folder, file_name)
         print(" > BEST MODEL : {}".format(checkpoint_path))
-        save_model(model, optimizer, current_step, epoch ,r, checkpoint_path, model_loss=target_loss)
+        save_model(model, optimizer, current_step, epoch, r, checkpoint_path, model_loss=target_loss, **kwargs)
     best_loss = target_loss
     return best_loss

View File

@@ -8,9 +8,9 @@ from torch.optim.optimizer import Optimizer, required

 class RAdam(Optimizer):
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
-        if not 0.0 <= lr:
+        if lr < 0.0:
             raise ValueError("Invalid learning rate: {}".format(lr))
-        if not 0.0 <= eps:
+        if eps < 0.0:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if not 0.0 <= betas[0] < 1.0:
             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
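
Aside: the expressions not 0.0 <= lr and lr < 0.0 are equivalent for ordinary floats but not for NaN, where every comparison is False, so only the original guard would have rejected a NaN learning rate. A quick check:

nan = float('nan')
print(not 0.0 <= nan)  # True: NaN fails every comparison, so the old guard fires
print(nan < 0.0)       # False: the rewritten guard lets NaN through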
@@ -94,4 +94,4 @@ class RAdam(Optimizer):
                 p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
                 p.data.copy_(p_data_fp32)

-        return loss
\ No newline at end of file
+        return loss

View File

@@ -1,5 +1,5 @@
 import pkg_resources
-installed = {pkg.key for pkg in pkg_resources.working_set}
+installed = {pkg.key for pkg in pkg_resources.working_set} #pylint: disable=not-an-iterable
 if 'tensorflow' in installed or 'tensorflow-gpu' in installed:
     import tensorflow as tf
 import torch
@@ -7,7 +7,7 @@ import numpy as np
 from .text import text_to_sequence, phoneme_to_sequence


-def text_to_seqvec(text, CONFIG, use_cuda):
+def text_to_seqvec(text, CONFIG):
     text_cleaner = [CONFIG.text_cleaner]
     # text ot phonemes to sequence vector
     if CONFIG.use_phonemes:
@@ -37,7 +37,7 @@ def numpy_to_tf(np_array, dtype):
     return tensor


-def compute_style_mel(style_wav, ap, use_cuda):
+def compute_style_mel(style_wav, ap):
     style_mel = ap.melspectrogram(
         ap.load_wav(style_wav)).expand_dims(0)
     return style_mel
@@ -58,13 +58,13 @@ def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):

 def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
-    if CONFIG.use_gst:
-        raise NotImplemented(' [!] GST inference not implemented for TF')
+    if CONFIG.use_gst and style_mel is not None:
+        raise NotImplementedError(' [!] GST inference not implemented for TF')
     if truncated:
-        raise NotImplemented(' [!] Truncated inference not implemented for TF')
+        raise NotImplementedError(' [!] Truncated inference not implemented for TF')
     # TODO: handle multispeaker case
     decoder_output, postnet_output, alignments, stop_tokens = model(
-        inputs)
+        inputs, speaker_ids=speaker_id)
     return decoder_output, postnet_output, alignments, stop_tokens
@@ -153,9 +153,9 @@ def synthesis(model,
     # GST processing
     style_mel = None
     if CONFIG.model == "TacotronGST" and style_wav is not None:
-        style_mel = compute_style_mel(style_wav, ap, use_cuda)
+        style_mel = compute_style_mel(style_wav, ap)
     # preprocess the given text
-    inputs = text_to_seqvec(text, CONFIG, use_cuda)
+    inputs = text_to_seqvec(text, CONFIG)
     # pass tensors to backend
     if backend == 'torch':
         speaker_id = id_to_torch(speaker_id)

View File

@@ -9,7 +9,7 @@ def check_update(model, grad_clip, ignore_stopnet=False):
         grad_norm = torch.nn.utils.clip_grad_norm_([param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
     else:
         grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
-    if torch.isinf(grad_norm):
+    if np.isinf(grad_norm):
         print(" | > Gradient is INF !!")
         skip_flag = True
     return grad_norm, skip_flag
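
Aside: the switch from torch.isinf to np.isinf presumably accounts for clip_grad_norm_ returning a plain Python float in the torch version targeted here; torch.isinf only accepts tensors, while np.isinf accepts floats and arrays alike:

import numpy as np

grad_norm = float('inf')    # a plain float, as older clip_grad_norm_ returns
print(np.isinf(grad_norm))  # True: np.isinf also works on Python floats
# torch.isinf(grad_norm) would raise TypeError here: it expects a Tensor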
@@ -62,6 +62,7 @@ def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn
     }]


+# pylint: disable=protected-access
 class NoamLR(torch.optim.lr_scheduler._LRScheduler):
     def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
         self.warmup_steps = float(warmup_steps)
@@ -87,4 +88,4 @@ def gradual_training_scheduler(global_step, config):
     for values in config.gradual_training:
         if global_step * num_gpus >= values[0]:
             new_values = values
-    return new_values[1], new_values[2]
\ No newline at end of file
+    return new_values[1], new_values[2]