mirror of https://github.com/coqui-ai/TTS.git

commit fb3cd96893
Merge remote-tracking branch 'TTS/dev' into dev
@@ -15,6 +15,7 @@ if [[ "$TEST_SUITE" == "unittest" ]]; then
     nosetests TTS.speaker_encoder.tests --nocapture
     nosetests TTS.vocoder.tests --nocapture
     nosetests TTS.tests --nocapture
+    nosetests TTS.tf.tests --nocapture
     popd
     # Test server package
     ./tests/test_server_package.sh
@@ -29,6 +29,7 @@
     "num_mels": 80, // size of the mel spec frame.
     "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
     "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
+    "spec_gain": 20,

     // Normalization parameters
     "signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.

@@ -86,7 +87,7 @@
     "prenet_type": "bn", // "original" or "bn".
     "prenet_dropout": false, // enable/disable dropout at prenet.

-    // ATTENTION
+    // TACOTRON ATTENTION
     "attention_type": "original", // 'original' or 'graves'
     "attention_heads": 4, // number of attention heads (only for 'graves')
     "attention_norm": "sigmoid", // softmax or sigmoid.
@@ -1,5 +1,3 @@
-from math import sqrt
-
 import torch
 from torch import nn

@@ -65,10 +63,11 @@ class Tacotron2(TacotronAbstract):
         self._init_backward_decoder()
         # setup DDC
         if self.double_decoder_consistency:
-            self.coarse_decoder = Decoder(decoder_in_features, self.decoder_output_dim, ddc_r, attn_type, attn_win,
-                                          attn_norm, prenet_type, prenet_dropout,
-                                          forward_attn, trans_agent, forward_attn_mask,
-                                          location_attn, attn_K, separate_stopnet, proj_speaker_dim)
+            self.coarse_decoder = Decoder(
+                decoder_in_features, self.decoder_output_dim, ddc_r, attn_type,
+                attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn,
+                trans_agent, forward_attn_mask, location_attn, attn_K,
+                separate_stopnet, proj_speaker_dim)

     @staticmethod
     def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):
@@ -12,3 +12,4 @@ tqdm
 soundfile
 phonemizer
 bokeh==1.4.0
+inflect
@@ -12,6 +12,7 @@ flask
 scipy
 tqdm
 soundfile
+inflect
 phonemizer
 bokeh==1.4.0
 nose
@@ -3,6 +3,8 @@
     "tts_file":"best_model.pth.tar", // tts checkpoint file
     "tts_config":"config.json", // tts config.json file
     "tts_speakers": null, // json file listing speaker ids. null if no speaker embedding.
+    "vocoder_config":null,
+    "vocoder_file": null,
     "wavernn_lib_path": null, // Rootpath to wavernn project folder to be imported. If this is null, model uses GL for speech synthesis.
     "wavernn_path":null, // wavernn model root path
     "wavernn_file":null, // wavernn checkpoint file name
@@ -29,21 +29,18 @@ websites = r"[.](com|net|org|io|gov)"
 class Synthesizer(object):
     def __init__(self, config):
         self.wavernn = None
-        self.pwgan = None
+        self.vocoder_model = None
         self.config = config
         self.use_cuda = self.config.use_cuda
         if self.use_cuda:
             assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
         self.load_tts(self.config.tts_checkpoint, self.config.tts_config,
                       self.config.use_cuda)
-        if self.config.vocoder_checkpoint:
+        if self.config.vocoder_file:
             self.load_vocoder(self.config.vocoder_checkpoint, self.config.vocoder_config, self.config.use_cuda)
         if self.config.wavernn_lib_path:
             self.load_wavernn(self.config.wavernn_lib_path, self.config.wavernn_file,
                               self.config.wavernn_config, self.config.use_cuda)
-        if self.config.pwgan_file:
-            self.load_pwgan(self.config.pwgan_lib_path, self.config.pwgan_file,
-                            self.config.pwgan_config, self.config.use_cuda)

     def load_tts(self, tts_checkpoint, tts_config, use_cuda):
         # pylint: disable=global-statement
@@ -129,27 +126,6 @@ class Synthesizer(object):
             self.wavernn.cuda()
         self.wavernn.eval()

-    def load_pwgan(self, lib_path, model_file, model_config, use_cuda):
-        if lib_path:
-            # set this if ParallelWaveGAN is not installed globally
-            sys.path.append(lib_path)
-        try:
-            #pylint: disable=import-outside-toplevel
-            from parallel_wavegan.models import ParallelWaveGANGenerator
-        except ImportError as e:
-            raise RuntimeError(f"cannot import parallel-wavegan, either install it or set its directory using the --pwgan_lib_path command line argument: {e}")
-        print(" > Loading PWGAN model ...")
-        print(" | > model config: ", model_config)
-        print(" | > model file: ", model_file)
-        with open(model_config) as f:
-            self.pwgan_config = yaml.load(f, Loader=yaml.Loader)
-        self.pwgan = ParallelWaveGANGenerator(**self.pwgan_config["generator_params"])
-        self.pwgan.load_state_dict(torch.load(model_file, map_location="cpu")["model"]["generator"])
-        self.pwgan.remove_weight_norm()
-        if use_cuda:
-            self.pwgan.cuda()
-        self.pwgan.eval()
-
     def save_wav(self, wav, path):
         # wav *= 32767 / max(1e-8, np.max(np.abs(wav)))
         wav = np.array(wav)
@@ -202,9 +178,9 @@ class Synthesizer(object):
         inputs = numpy_to_torch(inputs, torch.long, cuda=self.use_cuda)
         inputs = inputs.unsqueeze(0)
         # synthesize voice
-        decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(self.tts_model, inputs, self.tts_config, False, speaker_id, None)
-        # convert outputs to numpy
+        _, postnet_output, _, _ = run_model_torch(self.tts_model, inputs, self.tts_config, False, speaker_id, None)
         if self.vocoder_model:
+            # use native vocoder model
             vocoder_input = postnet_output[0].transpose(0, 1).unsqueeze(0)
             wav = self.vocoder_model.inference(vocoder_input)
             if self.use_cuda:
@@ -213,6 +189,7 @@ class Synthesizer(object):
             wav = wav.numpy()
             wav = wav.flatten()
         elif self.wavernn:
+            # use 3rd paty wavernn
             vocoder_input = None
             if self.tts_config.model == "Tacotron":
                 vocoder_input = torch.FloatTensor(self.ap.out_linear_to_mel(linear_spec=postnet_output.T).T).T.unsqueeze(0)
@@ -221,6 +198,15 @@ class Synthesizer(object):
             if self.use_cuda:
                 vocoder_input.cuda()
             wav = self.wavernn.generate(vocoder_input, batched=self.config.is_wavernn_batched, target=11000, overlap=550)
+        else:
+            # use GL
+            if self.use_cuda:
+                postnet_output = postnet_output[0].cpu()
+            else:
+                postnet_output = postnet_output[0]
+            postnet_output = postnet_output.numpy()
+            wav = inv_spectrogram(postnet_output, self.ap, self.tts_config)
+
         # trim silence
         wav = trim_silence(wav, self.ap)
setup.py (2 changes)

@@ -106,8 +106,8 @@ setup(
         "matplotlib",
         "Pillow",
         "flask",
-        # "lws",
         "tqdm",
+        "inflect",
         "bokeh==1.4.0",
         "soundfile",
         "phonemizer @ https://github.com/bootphon/phonemizer/tarball/master",
@@ -5,9 +5,8 @@
     "wavernn_lib_path": null, // Rootpath to wavernn project folder to be imported. If this is null, model uses GL for speech synthesis.
     "wavernn_file": null, // wavernn checkpoint file name
     "wavernn_config": null, // wavernn config file
-    "pwgan_lib_path": null,
-    "pwgan_file": null,
-    "pwgan_config": null,
+    "vocoder_config":null,
+    "vocoder_file": null,
     "is_wavernn_batched":true,
     "port": 5002,
     "use_cuda": false,
@@ -5,10 +5,10 @@
     "audio":{
         // Audio processing parameters
         "num_mels": 80, // size of the mel spec frame.
-        "num_freq": 1025, // number of stft frequency levels. Size of the linear spectogram frame.
+        "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame.
         "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled.
-        "frame_length_ms": 50, // stft window length in ms.
-        "frame_shift_ms": 12.5, // stft window hop-lengh in ms.
+        "hop_length": 256,
+        "win_length": 1024,
         "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis.
         "min_level_db": -100, // normalization range
         "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air.

@@ -19,7 +19,8 @@
         "max_norm": 4, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
         "mel_fmin": 0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
         "mel_fmax": 8000, // maximum freq level for mel-spec. Tune for dataset!!
-        "do_trim_silence": false
+        "do_trim_silence": false,
+        "spec_gain": 20
     },

     "characters":{
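Note on the audio block above: at sample_rate 22050, the removed frame_shift_ms 12.5 corresponds to 22050 x 0.0125, approximately 276 samples, and frame_length_ms 50 to approximately 1103 samples, so the new explicit hop_length 256 and win_length 1024 look like the nearest power-of-two equivalents of the old millisecond-based settings.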
@@ -76,7 +76,7 @@ class TestTTSDataset(unittest.TestCase):
             # TODO: more assertion here
             assert type(speaker_name[0]) is str
             assert linear_input.shape[0] == c.batch_size
-            assert linear_input.shape[2] == self.ap.num_freq
+            assert linear_input.shape[2] == self.ap.fft_size // 2 + 1
             assert mel_input.shape[0] == c.batch_size
             assert mel_input.shape[2] == c.audio['num_mels']
             # check normalization ranges
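The new assertion follows from the real-valued STFT: an FFT of size fft_size yields fft_size // 2 + 1 frequency bins (513 for fft_size 1024), which replaces the previously hard-coded num_freq.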
@@ -28,6 +28,7 @@ class TacotronTrainTest(unittest.TestCase):
         mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
         mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
         mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
+        mel_lengths[0] = 30
         stop_targets = torch.zeros(8, 30, 1).float().to(device)
         speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
@@ -1,63 +0,0 @@
-import os
-import torch
-import unittest
-import numpy as np
-import tensorflow as tf
-
-from TTS.utils.io import load_config
-from TTS.tf.models.tacotron2 import Tacotron2
-
-#pylint: disable=unused-variable
-
-torch.manual_seed(1)
-use_cuda = torch.cuda.is_available()
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-file_path = os.path.dirname(os.path.realpath(__file__))
-c = load_config(os.path.join(file_path, 'test_config.json'))
-
-
-class TacotronTFTrainTest(unittest.TestCase):
-
-    @staticmethod
-    def generate_dummy_inputs():
-        chars_seq = torch.randint(0, 24, (8, 128)).long().to(device)
-        chars_seq_lengths = torch.randint(100, 128, (8, )).long().to(device)
-        chars_seq_lengths = torch.sort(chars_seq_lengths, descending=True)[0]
-        mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
-        mel_postnet_spec = torch.rand(8, 30, c.audio['num_mels']).to(device)
-        mel_lengths = torch.randint(20, 30, (8, )).long().to(device)
-        stop_targets = torch.zeros(8, 30, 1).float().to(device)
-        speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
-
-        chars_seq = tf.convert_to_tensor(chars_seq.cpu().numpy())
-        chars_seq_lengths = tf.convert_to_tensor(chars_seq_lengths.cpu().numpy())
-        mel_spec = tf.convert_to_tensor(mel_spec.cpu().numpy())
-        return chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths,\
-            stop_targets, speaker_ids
-
-    def test_train_step(self):
-        ''' test forward pass '''
-        chars_seq, chars_seq_lengths, mel_spec, mel_postnet_spec, mel_lengths,\
-            stop_targets, speaker_ids = self.generate_dummy_inputs()
-
-        for idx in mel_lengths:
-            stop_targets[:, int(idx.item()):, 0] = 1.0
-
-        stop_targets = stop_targets.view(chars_seq.shape[0],
-                                         stop_targets.size(1) // c.r, -1)
-        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
-
-        model = Tacotron2(num_chars=24, r=c.r, num_speakers=5)
-        # training pass
-        output = model(chars_seq, chars_seq_lengths, mel_spec, training=True)
-
-        # check model output shapes
-        assert np.all(output[0].shape == mel_spec.shape)
-        assert np.all(output[1].shape == mel_spec.shape)
-        assert output[2].shape[2] == chars_seq.shape[1]
-        assert output[2].shape[1] == (mel_spec.shape[1] // model.decoder.r)
-        assert output[3].shape[1] == (mel_spec.shape[1] // model.decoder.r)
-
-        # inference pass
-        output = model(chars_seq, training=False)
@@ -0,0 +1,39 @@
+# Convert Tensorflow Tacotron2 model to TF-Lite binary
+
+import argparse
+
+from TTS.utils.io import load_config
+from TTS.utils.text.symbols import symbols, phonemes
+from TTS.tf.utils.generic_utils import setup_model
+from TTS.tf.utils.io import load_checkpoint
+from TTS.tf.utils.tflite import convert_tacotron2_to_tflite
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--tf_model',
+                    type=str,
+                    help='Path to target torch model to be converted to TF.')
+parser.add_argument('--config_path',
+                    type=str,
+                    help='Path to config file of torch model.')
+parser.add_argument('--output_path',
+                    type=str,
+                    help='path to tflite output binary.')
+args = parser.parse_args()
+
+# Set constants
+CONFIG = load_config(args.config_path)
+
+# load the model
+c = CONFIG
+num_speakers = 0
+num_chars = len(phonemes) if c.use_phonemes else len(symbols)
+model = setup_model(num_chars, num_speakers, c, enable_tflite=True)
+model.build_inference()
+model = load_checkpoint(model, args.tf_model)
+model.decoder.set_max_decoder_steps(1000)
+
+# create tflite model
+tflite_model = convert_tacotron2_to_tflite(model, output_path=args.output_path)
+
+print(f'Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.')
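A plausible invocation of the new script (the script and file names here are illustrative, not taken from this commit):

    python convert_tacotron2_tflite.py --tf_model tts_tf_checkpoint_360000.pkl --config_path config.json --output_path tacotron2.tflite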
@@ -26,7 +26,7 @@ parser.add_argument('--config_path',
                     help='Path to config file of torch model.')
 parser.add_argument('--output_path',
                     type=str,
-                    help='path to save TF model weights.')
+                    help='path to output file including file name to save TF model.')
 args = parser.parse_args()

 # load model config
@@ -65,18 +65,18 @@ model_tf = Tacotron2(num_chars=num_chars,
 # TODO: set layer names so that we can remove these manual matching
 common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
 var_map = [
-    ('tacotron2/embedding/embeddings:0', 'embedding.weight'),
-    ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/kernel:0',
+    ('embedding/embeddings:0', 'embedding.weight'),
+    ('encoder/lstm/forward_lstm/lstm_cell_1/kernel:0',
      'encoder.lstm.weight_ih_l0'),
-    ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0',
+    ('encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0',
      'encoder.lstm.weight_hh_l0'),
-    ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/kernel:0',
+    ('encoder/lstm/backward_lstm/lstm_cell_2/kernel:0',
      'encoder.lstm.weight_ih_l0_reverse'),
-    ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0',
+    ('encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0',
      'encoder.lstm.weight_hh_l0_reverse'),
-    ('tacotron2/encoder/lstm/forward_lstm/lstm_cell_1/bias:0',
+    ('encoder/lstm/forward_lstm/lstm_cell_1/bias:0',
      ('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')),
-    ('tacotron2/encoder/lstm/backward_lstm/lstm_cell_2/bias:0',
+    ('encoder/lstm/backward_lstm/lstm_cell_2/bias:0',
      ('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')),
     ('attention/v/kernel:0', 'decoder.attention.v.linear_layer.weight'),
     ('decoder/linear_projection/kernel:0',

@@ -86,8 +86,7 @@ var_map = [

 # %%
 # get tf_model graph
-input_ids, input_lengths, mel_outputs, mel_lengths = tf_create_dummy_inputs()
-mel_pred = model_tf(input_ids, training=False)
+mel_pred = model_tf.build_inference()

 # get tf variables
 tf_vars = model_tf.weights
@@ -109,12 +109,18 @@ class Attention(keras.layers.Layer):
             raise ValueError("Unknown value for attention norm type")

     def init_states(self, batch_size, value_length):
-        states = ()
+        states = []
         if self.use_loc_attn:
             attention_cum = tf.zeros([batch_size, value_length])
             attention_old = tf.zeros([batch_size, value_length])
-            states = (attention_cum, attention_old)
-        return states
+            states = [attention_cum, attention_old]
+        if self.use_forward_attn:
+            alpha = tf.concat([
+                tf.ones([batch_size, 1]),
+                tf.zeros([batch_size, value_length])[:, :-1] + 1e-7
+            ], 1)
+            states.append(alpha)
+        return tuple(states)

     def process_values(self, values):
         """ cache values for decoder iterations """
@@ -125,7 +131,7 @@ class Attention(keras.layers.Layer):
     def get_loc_attn(self, query, states):
         """ compute location attention, query layer and
         unnorm. attention weights"""
-        attention_cum, attention_old = states
+        attention_cum, attention_old = states[:2]
         attn_cat = tf.stack([attention_old, attention_cum], axis=2)

         processed_query = self.query_layer(tf.expand_dims(query, 1))
@@ -150,6 +156,23 @@ class Attention(keras.layers.Layer):
             score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
         return score

+    def apply_forward_attention(self, alignment, alpha):  #pylint: disable=no-self-use
+        # forward attention
+        fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)), constant_values=0.0)
+        # compute transition potentials
+        new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment
+        # renormalize attention weights
+        new_alpha = new_alpha / tf.reduce_sum(new_alpha, axis=1, keepdims=True)
+        return new_alpha
+
+    def update_states(self, old_states, scores_norm, attn_weights, new_alpha=None):
+        states = []
+        if self.use_loc_attn:
+            states = [old_states[0] + scores_norm, attn_weights]
+        if self.use_forward_attn:
+            states.append(new_alpha)
+        return tuple(states)
+
     def call(self, query, states):
         """
         shapes:
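The added apply_forward_attention is the forward-attention update with a fixed transition weight u = 0.5: new_alpha(n) is proportional to ((1 - u) * alpha(n) + u * alpha(n-1) + 1e-8) * alignment(n), renormalized over n, where the alpha(n-1) term comes from the zero-padded one-step shift (fwd_shifted_alpha).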
@@ -165,13 +188,19 @@ class Attention(keras.layers.Layer):
         # self.apply_score_masking(score, mask)
         # attn_weights shape == (batch_size, max_length, 1)

-        attn_weights = self.norm_func(score)
+        # normalize attention scores
+        scores_norm = self.norm_func(score)
+        attn_weights = scores_norm

-        # update attention states
-        if self.use_loc_attn:
-            states = (states[0] + attn_weights, attn_weights)
-        else:
-            states = ()
+        # apply forward attention
+        new_alpha = None
+        if self.use_forward_attn:
+            new_alpha = self.apply_forward_attention(attn_weights, states[-1])
+            attn_weights = new_alpha
+
+        # update states tuple
+        # states = (cum_attn_weights, attn_weights, new_alpha)
+        states = self.update_states(states, scores_norm, attn_weights, new_alpha)

         # context_vector shape after sum == (batch_size, hidden_size)
         context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False)
@@ -1,4 +1,3 @@
-
 import tensorflow as tf
 from tensorflow import keras
 from TTS.tf.utils.tf_utils import shape_list
@@ -58,12 +57,16 @@ class Decoder(keras.layers.Layer):
     #pylint: disable=unused-argument
     def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm, prenet_type,
                  prenet_dropout, use_forward_attn, use_trans_agent, use_forward_attn_mask,
-                 use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, **kwargs):
+                 use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, enable_tflite, **kwargs):
         super(Decoder, self).__init__(**kwargs)
         self.frame_dim = frame_dim
         self.r_init = tf.constant(r, dtype=tf.int32)
         self.r = tf.constant(r, dtype=tf.int32)
+        self.output_dim = r * self.frame_dim
         self.separate_stopnet = separate_stopnet
+        self.enable_tflite = enable_tflite
+
+        # layer constants
         self.max_decoder_steps = tf.constant(1000, dtype=tf.int32)
         self.stop_thresh = tf.constant(0.5, dtype=tf.float32)
@@ -80,7 +83,7 @@ class Decoder(keras.layers.Layer):
                              [self.prenet_dim, self.prenet_dim],
                              bias=False,
                              name='prenet')
-        self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name=f'{self.name}/attention_rnn', )
+        self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name='attention_rnn', )
         self.attention_rnn_dropout = keras.layers.Dropout(0.5)

         # TODO: implement other attn options
@@ -94,10 +97,10 @@ class Decoder(keras.layers.Layer):
                                    use_trans_agent=use_trans_agent,
                                    use_forward_attn_mask=use_forward_attn_mask,
                                    name='attention')
-        self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name=f'{self.name}/decoder_rnn')
+        self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name='decoder_rnn')
         self.decoder_rnn_dropout = keras.layers.Dropout(0.5)
-        self.linear_projection = keras.layers.Dense(self.frame_dim * r, name=f'{self.name}/linear_projection/linear_layer')
-        self.stopnet = keras.layers.Dense(1, name=f'{self.name}/stopnet/linear_layer')
+        self.linear_projection = keras.layers.Dense(self.frame_dim * r, name='linear_projection/linear_layer')
+        self.stopnet = keras.layers.Dense(1, name='stopnet/linear_layer')

     def set_max_decoder_steps(self, new_max_steps):
@@ -105,6 +108,7 @@ class Decoder(keras.layers.Layer):

     def set_r(self, new_r):
         self.r = tf.constant(new_r, dtype=tf.int32)
+        self.output_dim = self.frame_dim * new_r

     def build_decoder_initial_states(self, batch_size, memory_dim, memory_length):
         zero_frame = tf.zeros([batch_size, self.frame_dim])
@@ -183,6 +187,7 @@ class Decoder(keras.layers.Layer):
         outputs = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
         attentions = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)
         stop_tokens = tf.TensorArray(dtype=tf.float32, size=0, clear_after_read=False, dynamic_size=True)

         # pre-computes
         self.attention.process_values(memory)
+
@@ -226,7 +231,70 @@ class Decoder(keras.layers.Layer):
         outputs = tf.reshape(outputs, [B, -1, self.frame_dim])
         return outputs, stop_tokens, attentions

+    def decode_inference_tflite(self, memory, states):
+        """Inference with TF-Lite compatibility. It assumes
+        batch_size is 1"""
+        # init states
+        # dynamic_shape is not supported in TFLite
+        outputs = tf.TensorArray(dtype=tf.float32,
+                                 size=self.max_decoder_steps,
+                                 element_shape=tf.TensorShape(
+                                     [self.output_dim]),
+                                 clear_after_read=False,
+                                 dynamic_size=False)
+        # stop_flags = tf.TensorArray(dtype=tf.bool,
+        #                             size=self.max_decoder_steps,
+        #                             element_shape=tf.TensorShape(
+        #                                 []),
+        #                             clear_after_read=False,
+        #                             dynamic_size=False)
+        attentions = ()
+        stop_tokens = ()
+
+        # pre-computes
+        self.attention.process_values(memory)
+
+        # iter vars
+        stop_flag = tf.constant(False, dtype=tf.bool)
+        step_count = tf.constant(0, dtype=tf.int32)
+
+        def _body(step, memory, states, outputs, stop_flag):
+            frame_next = states[0]
+            prenet_next = self.prenet(frame_next, training=False)
+            output, stop_token, states, _ = self.step(prenet_next,
+                                                      states,
+                                                      None,
+                                                      training=False)
+            stop_token = tf.math.sigmoid(stop_token)
+            stop_flag = tf.greater(stop_token, self.stop_thresh)
+            stop_flag = tf.reduce_all(stop_flag)
+            # stop_flags = stop_flags.write(step, tf.logical_not(stop_flag))
+
+            outputs = outputs.write(step, tf.reshape(output, [-1]))
+            return step + 1, memory, states, outputs, stop_flag
+
+        cond = lambda step, m, s, o, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
+        step_count, memory, states, outputs, stop_flag = \
+            tf.while_loop(cond,
+                          _body,
+                          loop_vars=(step_count, memory, states, outputs,
+                                     stop_flag),
+                          parallel_iterations=32,
+                          swap_memory=True,
+                          maximum_iterations=self.max_decoder_steps)
+
+        outputs = outputs.stack()
+        outputs = tf.gather(outputs, tf.range(step_count))  # pylint: disable=no-value-for-parameter
+        outputs = tf.expand_dims(outputs, axis=[0])
+        outputs = tf.transpose(outputs, [1, 0, 2])
+        outputs = tf.reshape(outputs, [1, -1, self.frame_dim])
+        return outputs, stop_tokens, attentions
+
     def call(self, memory, states, frames=None, memory_seq_length=None, training=False):
         if training:
             return self.decode(memory, states, frames, memory_seq_length)
+        if self.enable_tflite:
+            return self.decode_inference_tflite(memory, states)
         return self.decode_inference(memory, states)
@@ -1,3 +1,4 @@
+import tensorflow as tf
 from tensorflow import keras

 from TTS.tf.layers.tacotron2 import Encoder, Decoder, Postnet
@@ -23,7 +24,8 @@ class Tacotron2(keras.models.Model):
                  forward_attn_mask=False,
                  location_attn=True,
                  separate_stopnet=True,
-                 bidirectional_decoder=False):
+                 bidirectional_decoder=False,
+                 enable_tflite=False):
         super(Tacotron2, self).__init__()
         self.r = r
         self.decoder_output_dim = decoder_output_dim
@@ -31,6 +33,7 @@ class Tacotron2(keras.models.Model):
         self.bidirectional_decoder = bidirectional_decoder
         self.num_speakers = num_speakers
         self.speaker_embed_dim = 256
+        self.enable_tflite = enable_tflite

         self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding')
         self.encoder = Encoder(512, name='encoder')
@@ -48,9 +51,12 @@ class Tacotron2(keras.models.Model):
                                use_location_attn=location_attn,
                                attn_K=attn_K,
                                separate_stopnet=separate_stopnet,
-                               speaker_emb_dim=self.speaker_embed_dim)
+                               speaker_emb_dim=self.speaker_embed_dim,
+                               name='decoder',
+                               enable_tflite=enable_tflite)
         self.postnet = Postnet(postnet_output_dim, 5, name='postnet')

+    @tf.function(experimental_relax_shapes=True)
     def call(self, characters, text_lengths=None, frames=None, training=None):
         if training:
             return self.training(characters, text_lengths, frames)
@@ -79,3 +85,24 @@ class Tacotron2(keras.models.Model):
         print(output_frames.shape)
         return decoder_frames, output_frames, attentions, stop_tokens

+    @tf.function(
+        experimental_relax_shapes=True,
+        input_signature=[
+            tf.TensorSpec([1, None], dtype=tf.int32),
+        ],)
+    def inference_tflite(self, characters):
+        B, T = shape_list(characters)
+        embedding_vectors = self.embedding(characters, training=False)
+        encoder_output = self.encoder(embedding_vectors, training=False)
+        decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
+        decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, training=False)
+        postnet_frames = self.postnet(decoder_frames, training=False)
+        output_frames = decoder_frames + postnet_frames
+        print(output_frames.shape)
+        return decoder_frames, output_frames, attentions, stop_tokens
+
+    def build_inference(self, ):
+        # TODO: issue https://github.com/PyCQA/pylint/issues/3613
+        input_ids = tf.random.uniform(shape=[1, 4], maxval=10, dtype=tf.int32)  #pylint: disable=unexpected-keyword-arg
+        self(input_ids)
+
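The input_signature on inference_tflite pins the exported graph to a [1, None] int32 tensor: a single utterance of variable text length. Batched TFLite inference is therefore not supported by this path, matching the batch_size-is-1 assumption in decode_inference_tflite.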
@@ -6,9 +6,7 @@ import numpy as np
 import tensorflow as tf


-def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):
-    checkpoint_path = 'tts_tf_checkpoint_{}.pkl'.format(current_step)
-    checkpoint_path = os.path.join(output_folder, checkpoint_path)
+def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
     state = {
         'model': model.weights,
         'optimizer': optimizer,
@@ -18,7 +16,7 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, **kwargs):
         'r': r
     }
     state.update(kwargs)
-    pickle.dump(state, open(checkpoint_path, 'wb'))
+    pickle.dump(state, open(output_path, 'wb'))


 def load_checkpoint(model, checkpoint_path):
@@ -27,7 +25,13 @@ def load_checkpoint(model, checkpoint_path):
     tf_vars = model.weights
     for tf_var in tf_vars:
         layer_name = tf_var.name
-        chkp_var_value = chkp_var_dict[layer_name]
+        try:
+            chkp_var_value = chkp_var_dict[layer_name]
+        except KeyError:
+            class_name = list(chkp_var_dict.keys())[0].split("/")[0]
+            layer_name = f"{class_name}/{layer_name}"
+            chkp_var_value = chkp_var_dict[layer_name]
+
         tf.keras.backend.set_value(tf_var, chkp_var_value)
     if 'r' in checkpoint.keys():
         model.decoder.set_r(checkpoint['r'])
@@ -72,7 +76,7 @@ def count_parameters(model, c):
     return model.count_params()


-def setup_model(num_chars, num_speakers, c):
+def setup_model(num_chars, num_speakers, c, enable_tflite=False):
     print(" > Using model: {}".format(c.model))
     MyModel = importlib.import_module('TTS.tf.models.' + c.model.lower())
     MyModel = getattr(MyModel, c.model)
@@ -95,5 +99,6 @@ def setup_model(num_chars, num_speakers, c):
                     location_attn=c.location_attn,
                     attn_K=c.attention_heads,
                     separate_stopnet=c.separate_stopnet,
-                    bidirectional_decoder=c.bidirectional_decoder)
+                    bidirectional_decoder=c.bidirectional_decoder,
+                    enable_tflite=enable_tflite)
     return model
@@ -0,0 +1,42 @@
+import pickle
+import datetime
+import tensorflow as tf
+
+
+def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
+    state = {
+        'model': model.weights,
+        'optimizer': optimizer,
+        'step': current_step,
+        'epoch': epoch,
+        'date': datetime.date.today().strftime("%B %d, %Y"),
+        'r': r
+    }
+    state.update(kwargs)
+    pickle.dump(state, open(output_path, 'wb'))
+
+
+def load_checkpoint(model, checkpoint_path):
+    checkpoint = pickle.load(open(checkpoint_path, 'rb'))
+    chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']}
+    tf_vars = model.weights
+    for tf_var in tf_vars:
+        layer_name = tf_var.name
+        try:
+            chkp_var_value = chkp_var_dict[layer_name]
+        except KeyError:
+            class_name = list(chkp_var_dict.keys())[0].split("/")[0]
+            layer_name = f"{class_name}/{layer_name}"
+            chkp_var_value = chkp_var_dict[layer_name]
+
+        tf.keras.backend.set_value(tf_var, chkp_var_value)
+    if 'r' in checkpoint.keys():
+        model.decoder.set_r(checkpoint['r'])
+    return model
+
+
+def load_tflite_model(tflite_path):
+    tflite_model = tf.lite.Interpreter(model_path=tflite_path)
+    tflite_model.allocate_tensors()
+    return tflite_model
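The new KeyError fallback in load_checkpoint appears to prepend the checkpoint's top-level name scope (e.g. 'tacotron2/') to the variable name, so weights saved under the old scoped layer names still resolve after the layer-name cleanup above.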
@@ -0,0 +1,30 @@
+import tensorflow as tf
+
+
+def convert_tacotron2_to_tflite(model,
+                                output_path=None,
+                                experimental_converter=True):
+    """Convert Tensorflow Tacotron2 model to TFLite. Save a binary file if output_path is
+    provided, else return TFLite model."""
+
+    concrete_function = model.inference_tflite.get_concrete_function()
+    converter = tf.lite.TFLiteConverter.from_concrete_functions(
+        [concrete_function])
+    converter.experimental_new_converter = experimental_converter
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.target_spec.supported_ops = [
+        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
+    ]
+    tflite_model = converter.convert()
+    if output_path is not None:
+        # same model binary if outputpath is provided
+        with open(output_path, 'wb') as f:
+            f.write(tflite_model)
+        return None
+    return tflite_model
+
+
+def load_tflite_model(tflite_path):
+    tflite_model = tf.lite.Interpreter(model_path=tflite_path)
+    tflite_model.allocate_tensors()
+    return tflite_model
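A hedged usage sketch of the two helpers above (the output path is illustrative; `model` is a Tacotron2 built with enable_tflite=True):

    from TTS.tf.utils.tflite import convert_tacotron2_to_tflite, load_tflite_model

    # write the converted binary to disk; the function returns None in this mode
    convert_tacotron2_to_tflite(model, output_path='tacotron2.tflite')
    # later, reload it as a TFLite interpreter ready for invoke()
    interpreter = load_tflite_model('tacotron2.tflite')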
@@ -68,6 +68,7 @@ class AudioProcessor(object):
         self.hop_length = hop_length
         self.win_length = win_length
         assert min_level_db != 0.0, " [!] min_level_db is 0"
+        assert self.win_length <= self.fft_size, " [!] win_length cannot be larger than fft_size"
         members = vars(self)
         for key, value in members.items():
             print(" | > {}:{}".format(key, value))
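With the defaults used in this commit the new guard passes (win_length 1024 equals fft_size 1024); a window longer than the FFT would otherwise be silently truncated during the STFT, which is presumably what this assert is meant to catch early.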
@@ -70,6 +70,31 @@ def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
     return decoder_output, postnet_output, alignments, stop_tokens


+def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
+    if CONFIG.use_gst and style_mel is not None:
+        raise NotImplementedError(' [!] GST inference not implemented for TfLite')
+    if truncated:
+        raise NotImplementedError(' [!] Truncated inference not implemented for TfLite')
+    if speaker_id is not None:
+        raise NotImplementedError(' [!] Multi-Speaker not implemented for TfLite')
+    # get input and output details
+    input_details = model.get_input_details()
+    output_details = model.get_output_details()
+    # reshape input tensor for the new input shape
+    model.resize_tensor_input(input_details[0]['index'], inputs.shape)
+    model.allocate_tensors()
+    detail = input_details[0]
+    # input_shape = detail['shape']
+    model.set_tensor(detail['index'], inputs)
+    # run the model
+    model.invoke()
+    # collect outputs
+    decoder_output = model.get_tensor(output_details[0]['index'])
+    postnet_output = model.get_tensor(output_details[1]['index'])
+    # tflite model only returns feature frames
+    return decoder_output, postnet_output, None, None
+
+
 def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens):
     postnet_output = postnet_output[0].data.cpu().numpy()
     decoder_output = decoder_output[0].data.cpu().numpy()
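A minimal sketch of driving run_model_tflite directly (the binary path and the dummy character ids are assumptions; CONFIG is the usual model config):

    import numpy as np
    from TTS.tf.utils.tflite import load_tflite_model

    interpreter = load_tflite_model('tacotron2.tflite')
    # batch of 1 fake character ids, int32 as required by the input signature
    inputs = np.random.randint(0, 24, size=(1, 50), dtype=np.int32)
    decoder_output, postnet_output, _, _ = run_model_tflite(interpreter, inputs, CONFIG, truncated=False)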
@@ -86,12 +111,18 @@ def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens):
     return postnet_output, decoder_output, alignment, stop_tokens


+def parse_outputs_tflite(postnet_output, decoder_output):
+    postnet_output = postnet_output[0]
+    decoder_output = decoder_output[0]
+    return postnet_output, decoder_output
+
+
 def trim_silence(wav, ap):
     return wav[:ap.find_endpoint(wav)]


 def inv_spectrogram(postnet_output, ap, CONFIG):
-    if CONFIG.model in ["Tacotron", "TacotronGST"]:
+    if CONFIG.model.lower() in ["tacotron"]:
         wav = ap.inv_spectrogram(postnet_output.T)
     else:
         wav = ap.inv_melspectrogram(postnet_output.T)
@@ -164,22 +195,31 @@ def synthesis(model,
         style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
         inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
         inputs = inputs.unsqueeze(0)
-    else:
+    elif backend == 'tf':
         # TODO: handle speaker id for tf model
         style_mel = numpy_to_tf(style_mel, tf.float32)
         inputs = numpy_to_tf(inputs, tf.int32)
         inputs = tf.expand_dims(inputs, 0)
+    elif backend == 'tflite':
+        style_mel = numpy_to_tf(style_mel, tf.float32)
+        inputs = numpy_to_tf(inputs, tf.int32)
+        inputs = tf.expand_dims(inputs, 0)
     # synthesize voice
     if backend == 'torch':
         decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(
             model, inputs, CONFIG, truncated, speaker_id, style_mel)
         postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
             postnet_output, decoder_output, alignments, stop_tokens)
-    else:
+    elif backend == 'tf':
         decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
             model, inputs, CONFIG, truncated, speaker_id, style_mel)
         postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf(
             postnet_output, decoder_output, alignments, stop_tokens)
+    elif backend == 'tflite':
+        decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
+            model, inputs, CONFIG, truncated, speaker_id, style_mel)
+        postnet_output, decoder_output = parse_outputs_tflite(
+            postnet_output, decoder_output)
     # convert outputs to numpy
     # plot results
     wav = None
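End to end, the new branch is selected with backend='tflite'. A hedged sketch (the full synthesis signature and return tuple are abbreviated assumptions based on the torch path):

    interpreter = load_tflite_model('tacotron2.tflite')  # hypothetical binary
    wav, *rest = synthesis(interpreter, "Hello world.", CONFIG, use_cuda=False,
                           ap=ap, backend='tflite', use_griffin_lim=True)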
@@ -1,10 +1,8 @@
-# -*- coding: utf-8 -*-
 """ from https://github.com/keithito/tacotron """
-
 import inflect
 import re


 _inflect = inflect.engine()
 _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
 _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
|
||||||
|
|
||||||
def _expand_number(m):
|
def _expand_number(m):
|
||||||
num = int(m.group(0))
|
num = int(m.group(0))
|
||||||
if num > 1000 and num < 3000:
|
if 1000 < num < 3000:
|
||||||
if num == 2000:
|
if num == 2000:
|
||||||
return 'two thousand'
|
return 'two thousand'
|
||||||
elif num > 2000 and num < 2010:
|
if 2000 < num < 2010:
|
||||||
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
return 'two thousand ' + _inflect.number_to_words(num % 100)
|
||||||
elif num % 100 == 0:
|
if num % 100 == 0:
|
||||||
return _inflect.number_to_words(num // 100) + ' hundred'
|
return _inflect.number_to_words(num // 100) + ' hundred'
|
||||||
else:
|
return _inflect.number_to_words(num,
|
||||||
return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
|
andword='',
|
||||||
else:
|
zero='oh',
|
||||||
|
group=2).replace(', ', ' ')
|
||||||
return _inflect.number_to_words(num, andword='')
|
return _inflect.number_to_words(num, andword='')
|
||||||
|
|
||||||
|
|
||||||
|
|
|
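The flattened ladder keeps the original behavior: 2000 still reads as 'two thousand', 2005 as 'two thousand five', 1900 as 'nineteen hundred', and other years in range fall through to two-digit grouping (1999 comes out roughly as 'nineteen ninety-nine').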
@@ -20,7 +20,7 @@ ap = AudioProcessor(**C.audio)


 def test_torch_stft():
-    torch_stft = TorchSTFT(ap.n_fft, ap.hop_length, ap.win_length)
+    torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length)
     # librosa stft
     wav = ap.load_wav(WAV_FILE)
     M_librosa = abs(ap._stft(wav))  # pylint: disable=protected-access
|
||||||
|
|
||||||
|
|
||||||
def test_stft_loss():
|
def test_stft_loss():
|
||||||
stft_loss = STFTLoss(ap.n_fft, ap.hop_length, ap.win_length)
|
stft_loss = STFTLoss(ap.fft_size, ap.hop_length, ap.win_length)
|
||||||
wav = ap.load_wav(WAV_FILE)
|
wav = ap.load_wav(WAV_FILE)
|
||||||
wav = torch.from_numpy(wav[None, :]).float()
|
wav = torch.from_numpy(wav[None, :]).float()
|
||||||
loss_m, loss_sc = stft_loss(wav, wav)
|
loss_m, loss_sc = stft_loss(wav, wav)
|
||||||
|
@@ -43,7 +43,7 @@ def test_stft_loss():


 def test_multiscale_stft_loss():
-    stft_loss = MultiScaleSTFTLoss([ap.n_fft//2, ap.n_fft, ap.n_fft*2],
+    stft_loss = MultiScaleSTFTLoss([ap.fft_size//2, ap.fft_size, ap.fft_size*2],
                                    [ap.hop_length // 2, ap.hop_length, ap.hop_length * 2],
                                    [ap.win_length // 2, ap.win_length, ap.win_length * 2])
     wav = ap.load_wav(WAV_FILE)
@@ -100,13 +100,18 @@ for i in range(1, len(model.layers)):
     diff = compare_torch_tf(out_torch, out_tf_)
     assert diff < 1e-5, diff

-dummy_input_torch = torch.ones((1, 80, 10))
+torch.manual_seed(0)
+dummy_input_torch = torch.rand((1, 80, 100))
 dummy_input_tf = tf.convert_to_tensor(dummy_input_torch.numpy())
+model.inference_padding = 0
+model_tf.inference_padding = 0
 output_torch = model.inference(dummy_input_torch)
 output_tf = model_tf(dummy_input_tf, training=False)
 assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(
     output_torch, output_tf)

 # save tf model
 save_checkpoint(model_tf, checkpoint['step'], checkpoint['epoch'],
                 args.output_path)
 print(' > Model conversion is successfully completed :).')
@@ -7,7 +7,6 @@ class ReflectionPad1d(tf.keras.layers.Layer):
         self.padding = padding

     def call(self, x):
-        print(x.shape)
         return tf.pad(x, [[0, 0], [self.padding, self.padding], [0, 0], [0, 0]], "REFLECT")
@@ -8,7 +8,9 @@ import tensorflow as tf
 from TTS.vocoder.tf.layers.melgan import ResidualStack, ReflectionPad1d


-class MelganGenerator(tf.keras.models.Model): # pylint: disable=too-many-ancestors
+#pylint: disable=too-many-ancestors
+#pylint: disable=abstract-method
+class MelganGenerator(tf.keras.models.Model):
     """ Melgan Generator TF implementation dedicated for inference with no
     weight norm """
     def __init__(self,
@@ -83,6 +85,7 @@ class MelganGenerator(tf.keras.models.Model):
         # self.model_layers = tf.keras.models.Sequential(self.initial_layer + self.upsample_layers + self.final_layers, name="layers")
         self.model_layers = self.initial_layer + self.upsample_layers + self.final_layers

+    @tf.function(experimental_relax_shapes=True)
     def call(self, c, training=False):
         """
         c : B x C x T
@@ -94,10 +97,15 @@ class MelganGenerator(tf.keras.models.Model):
     def inference(self, c):
         c = tf.transpose(c, perm=[0, 2, 1])
         c = tf.expand_dims(c, 2)
-        c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
+        # FIXME: TF had no replicate padding as in Torch
+        # c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
         o = c
         for layer in self.model_layers:
             o = layer(o)
         # o = self.model_layers(c)
         o = tf.transpose(o, perm=[0, 3, 2, 1])
         return o[:, :, 0, :]
+
+    def build_inference(self):
+        x = tf.random.uniform((1, self.in_channels, 4), dtype=tf.float32)
+        self(x, training=False)
@@ -3,8 +3,9 @@ import tensorflow as tf
 from TTS.vocoder.tf.models.melgan_generator import MelganGenerator
 from TTS.vocoder.tf.layers.pqmf import PQMF

-class MultibandMelganGenerator(MelganGenerator): # pylint: disable=too-many-ancestors
+#pylint: disable=too-many-ancestors
+#pylint: disable=abstract-method
+class MultibandMelganGenerator(MelganGenerator):
     def __init__(self,
                  in_channels=80,
                  out_channels=4,
@@ -37,7 +38,8 @@ class MultibandMelganGenerator(MelganGenerator):
     def inference(self, c):
         c = tf.transpose(c, perm=[0, 2, 1])
         c = tf.expand_dims(c, 2)
-        c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
+        # FIXME: TF had no replicate padding as in Torch
+        # c = tf.pad(c, [[0, 0], [self.inference_padding, self.inference_padding], [0, 0], [0, 0]], "REFLECT")
         o = c
         for layer in self.model_layers:
             o = layer(o)