diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index 6cd28bc8..64d3298b 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -2,26 +2,15 @@
 # -*- coding: utf-8 -*-

 import argparse
-import json
 import os
 import sys
 import string
-import time
 from argparse import RawTextHelpFormatter
 # pylint: disable=redefined-outer-name, unused-argument
 from pathlib import Path

-import numpy as np
-import torch
-from TTS.tts.utils.generic_utils import is_tacotron, setup_model
-from TTS.tts.utils.synthesis import synthesis
-from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
-from TTS.tts.utils.io import load_checkpoint
-from TTS.utils.audio import AudioProcessor
-from TTS.utils.io import load_config
 from TTS.utils.manage import ModelManager
 from TTS.utils.synthesizer import Synthesizer
-from TTS.vocoder.utils.generic_utils import setup_generator, interpolate_vocoder_input


 def str2bool(v):
@@ -29,17 +18,16 @@ def str2bool(v):
         return v
     if v.lower() in ('yes', 'true', 't', 'y', '1'):
         return True
-    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+    if v.lower() in ('no', 'false', 'f', 'n', '0'):
         return False
-    else:
-        raise argparse.ArgumentTypeError('Boolean value expected.')
+    raise argparse.ArgumentTypeError('Boolean value expected.')


 if __name__ == "__main__":
-
+    # pylint: disable=bad-continuation
     parser = argparse.ArgumentParser(description='''Synthesize speech on command line.\n\n'''

-    '''You can either use your trained model or choose a model from the provided list.\n'''
+    '''You can either use your trained model or choose a model from the provided list.\n'''\

     '''
 Example runs:
diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index c978e4fa..2f9b6f9b 100644
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -224,7 +224,7 @@ class GlowTts(nn.Module):
     def store_inverse(self):
         self.decoder.store_inverse()

-    def load_checkpoint(self, config, checkpoint_path, eval=False):
+    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
         state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
         self.load_state_dict(state['model'])
         if eval:
diff --git a/TTS/tts/models/speedy_speech.py b/TTS/tts/models/speedy_speech.py
index 7f5c660e..93496d59 100644
--- a/TTS/tts/models/speedy_speech.py
+++ b/TTS/tts/models/speedy_speech.py
@@ -188,10 +188,10 @@ class SpeedySpeech(nn.Module):
         o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
         o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
         y_lengths = o_dr.sum(1)
-        o_de, attn= self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
+        o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
         return o_de, attn

-    def load_checkpoint(self, config, checkpoint_path, eval=False):
+    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
         state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
         self.load_state_dict(state['model'])
         if eval:
diff --git a/TTS/tts/models/tacotron_abstract.py b/TTS/tts/models/tacotron_abstract.py
index 0a63b871..10953269 100644
--- a/TTS/tts/models/tacotron_abstract.py
+++ b/TTS/tts/models/tacotron_abstract.py
@@ -121,7 +121,7 @@ class TacotronAbstract(ABC, nn.Module):
     def inference(self):
         pass

-    def load_checkpoint(self, config, checkpoint_path, eval=False):
+    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
         state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
         self.load_state_dict(state['model'])
         self.decoder.set_r(state['r'])
diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py
index 17cba648..e5bb5891 100644
--- a/TTS/tts/utils/visual.py
+++ b/TTS/tts/utils/visual.py
@@ -50,7 +50,7 @@ def plot_spectrogram(spectrogram,
     spectrogram_ = spectrogram_.astype(
         np.float32) if spectrogram_.dtype == np.float16 else spectrogram_
     if ap is not None:
-        spectrogram_ = ap._denormalize(spectrogram_)  # pylint: disable=protected-access
+        spectrogram_ = ap.denormalize(spectrogram_)  # pylint: disable=protected-access
     fig = plt.figure(figsize=fig_size)
     plt.imshow(spectrogram_, aspect="auto", origin="lower")
     plt.colorbar()
diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index f7ca5f44..615e0d1d 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -17,7 +17,7 @@ from TTS.tts.utils.text import make_symbols, phonemes, symbols


 class Synthesizer(object):
-    def __init__(self, tts_checkpoint, tts_config, vocoder_checkpoint, vocoder_config, use_cuda):
+    def __init__(self, tts_checkpoint, tts_config, vocoder_checkpoint=None, vocoder_config=None, use_cuda=False):
         """Encapsulation of tts and vocoder models for inference.

         TODO: handle multi-speaker and GST inference.
@@ -25,9 +25,9 @@ class Synthesizer(object):
         Args:
             tts_checkpoint (str): path to the tts model file.
             tts_config (str): path to the tts config file.
-            vocoder_checkpoint (str): path to the vocoder model file.
-            vocoder_config (str): path to the vocoder config file.
-            use_cuda (bool): enable/disable cuda.
+            vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None.
+            vocoder_config (str, optional): path to the vocoder config file. Defaults to None.
+            use_cuda (bool, optional): enable/disable cuda. Defaults to False.
         """
         self.tts_checkpoint = tts_checkpoint
         self.tts_config = tts_config
@@ -38,6 +38,7 @@ class Synthesizer(object):
         self.vocoder_model = None
         self.num_speakers = 0
         self.tts_speakers = None
+        self.speaker_embedding_dim = None
         self.seg = self.get_segmenter("en")
         self.use_cuda = use_cuda
         if self.use_cuda:
@@ -116,7 +117,7 @@ class Synthesizer(object):
         print(sens)

         speaker_embedding = self.init_speaker(speaker_idx)
-        use_gl = not hasattr(self, 'vocoder_model')
+        use_gl = self.vocoder_model is None

         for sen in sens:
             # synthesize voice
@@ -134,17 +135,17 @@ class Synthesizer(object):
                                  speaker_embedding=speaker_embedding)
             if not use_gl:
                 # denormalize tts output based on tts audio config
-                mel_postnet_spec = self.ap._denormalize(mel_postnet_spec.T).T
+                mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T
                 device_type = "cuda" if self.use_cuda else "cpu"
                 # renormalize spectrogram based on vocoder config
-                vocoder_input = self.vocoder_ap._normalize(mel_postnet_spec.T)
+                vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
                 # compute scale factor for possible sample rate mismatch
-                scale_factor = [1, self.vocoder_config['audio']['sample_rate'] / self.ap.sample_rate]
+                scale_factor = [1, self.vocoder_config['audio']['sample_rate'] / self.ap.sample_rate]
                 if scale_factor[1] != 1:
                     print(" > interpolating tts model output.")
                     vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
                 else:
-                    vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)
+                    vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)  # pylint: disable=not-callable
                 # run vocoder model
                 # [1, T, C]
                 waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
diff --git a/TTS/vocoder/models/melgan_generator.py b/TTS/vocoder/models/melgan_generator.py
index e5fd46eb..3070eac7 100644
--- a/TTS/vocoder/models/melgan_generator.py
+++ b/TTS/vocoder/models/melgan_generator.py
@@ -96,7 +96,7 @@ class MelganGenerator(nn.Module):
             except ValueError:
                 layer.remove_weight_norm()

-    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument
+    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
         state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
         self.load_state_dict(state['model'])
         if eval:
diff --git a/TTS/vocoder/models/parallel_wavegan_generator.py b/TTS/vocoder/models/parallel_wavegan_generator.py
index f5ed7712..1d1bcdcb 100644
--- a/TTS/vocoder/models/parallel_wavegan_generator.py
+++ b/TTS/vocoder/models/parallel_wavegan_generator.py
@@ -158,7 +158,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
         return self._get_receptive_field_size(self.layers, self.stacks,
                                               self.kernel_size)

-    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument
+    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
         state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
         self.load_state_dict(state['model'])
         if eval:
diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py
index bb9d04b8..f4a5faa3 100644
--- a/TTS/vocoder/models/wavegrad.py
+++ b/TTS/vocoder/models/wavegrad.py
@@ -177,7 +177,7 @@ class Wavegrad(nn.Module):
         self.y_conv = weight_norm(self.y_conv)

-    def load_checkpoint(self, config, checkpoint_path, eval=False):
+    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
         state = torch.load(checkpoint_path,
                            map_location=torch.device('cpu'))
         self.load_state_dict(state['model'])
         if eval:
diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py
index bded4cd8..cb03deb3 100644
--- a/TTS/vocoder/models/wavernn.py
+++ b/TTS/vocoder/models/wavernn.py
@@ -500,7 +500,7 @@ class WaveRNN(nn.Module):

         return unfolded

-    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument
+    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
         state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
         self.load_state_dict(state['model'])
         if eval:
diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py
index 478d9c00..fb943a37 100644
--- a/TTS/vocoder/utils/generic_utils.py
+++ b/TTS/vocoder/utils/generic_utils.py
@@ -20,7 +20,7 @@ def interpolate_vocoder_input(scale_factor, spec):
         torch.tensor: interpolated spectrogram.
     """
     print(" > before interpolation :", spec.shape)
-    spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0)
+    spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0)  # pylint: disable=not-callable
    spec = torch.nn.functional.interpolate(spec,
                                            scale_factor=scale_factor,
                                            recompute_scale_factor=True,
diff --git a/tests/test_demo_server.py b/tests/test_demo_server.py
index 0576430c..bccff55d 100644
--- a/tests/test_demo_server.py
+++ b/tests/test_demo_server.py
@@ -2,7 +2,7 @@ import os
 import unittest

 from tests import get_tests_input_path, get_tests_output_path
-from TTS.server.synthesizer import Synthesizer
+from TTS.utils.synthesizer import Synthesizer
 from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.io import save_checkpoint
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
@@ -29,7 +29,7 @@ class DemoServerTest(unittest.TestCase):
         tts_root_path = get_tests_output_path()
         config['tts_checkpoint'] = os.path.join(tts_root_path, config['tts_checkpoint'])
         config['tts_config'] = os.path.join(tts_root_path, config['tts_config'])
-        synthesizer = Synthesizer(config)
+        synthesizer = Synthesizer(config['tts_checkpoint'], config['tts_config'], None, None)
         synthesizer.tts("Better this test works!!")

     def test_split_into_sentences(self):
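A minimal usage sketch (not part of the patch) of the relaxed Synthesizer constructor introduced above: with vocoder_checkpoint and vocoder_config left at their new None defaults, use_gl evaluates to True and Griffin-Lim is used in place of a neural vocoder. The checkpoint and config paths below are placeholders.

    # Hypothetical paths; only tts_checkpoint and tts_config are required with the new defaults.
    from TTS.utils.synthesizer import Synthesizer

    synthesizer = Synthesizer("/path/to/tts_model.pth.tar", "/path/to/tts_config.json")
    # No vocoder given, so self.vocoder_model stays None and Griffin-Lim is used (use_gl is True).
    wav = synthesizer.tts("Better this test works!!")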