linter fixes and test fixes

Eren Gölge 2021-01-22 02:32:35 +01:00
parent 32d21545ac
commit c990b3a59c
12 changed files with 26 additions and 37 deletions

View File

@@ -2,26 +2,15 @@
# -*- coding: utf-8 -*-
import argparse
import json
import os
import sys
import string
import time
from argparse import RawTextHelpFormatter
# pylint: disable=redefined-outer-name, unused-argument
from pathlib import Path
import numpy as np
import torch
from TTS.tts.utils.generic_utils import is_tacotron, setup_model
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.io import load_checkpoint
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
from TTS.vocoder.utils.generic_utils import setup_generator, interpolate_vocoder_input
def str2bool(v):
@@ -29,17 +18,16 @@ def str2bool(v):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
-    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
-    else:
-        raise argparse.ArgumentTypeError('Boolean value expected.')
+    raise argparse.ArgumentTypeError('Boolean value expected.')


if __name__ == "__main__":
    # pylint: disable=bad-continuation
    parser = argparse.ArgumentParser(description='''Synthesize speech on command line.\n\n'''
-                                     '''You can either use your trained model or choose a model from the provided list.\n'''
+                                     '''You can either use your trained model or choose a model from the provided list.\n'''\
        '''
Example runs:

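The str2bool change above drops the elif/else branches that follow a return/raise, which is what pylint's no-else-return and no-else-raise checks ask for. Below is a minimal, standalone sketch of how such a helper is typically wired into argparse; the flag name and the isinstance guard are illustrative assumptions, not quoted from this file.

import argparse

def str2bool(v):
    # accept real booleans as well as common textual spellings
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

parser = argparse.ArgumentParser()
# hypothetical flag, used only for illustration
parser.add_argument('--use_cuda', type=str2bool, nargs='?', const=True, default=False)
print(parser.parse_args(['--use_cuda', 'yes']).use_cuda)  # True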
View File

@@ -224,7 +224,7 @@ class GlowTts(nn.Module):
    def store_inverse(self):
        self.decoder.store_inverse()

-    def load_checkpoint(self, config, checkpoint_path, eval=False):
+    def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.load_state_dict(state['model'])
        if eval:

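In the load_checkpoint hunks throughout this commit, `config` is accepted but unused and `eval` shadows the Python builtin, hence the added unused-argument, redefined-builtin disables. The hunks cut off at `if eval:`; the sketch below shows the general shape of the pattern on a stand-in module. The tiny model and the eval-branch body are assumptions for illustration, not code from the commit.

import torch
from torch import nn

class TinyModel(nn.Module):
    """Stand-in module; only the checkpoint-loading pattern matters here."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
        # map to CPU so checkpoints trained on GPU load on any machine
        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.load_state_dict(state['model'])
        if eval:
            self.eval()  # inference mode: disables dropout, uses running batch-norm stats
            assert not self.training

model = TinyModel()
torch.save({'model': model.state_dict()}, '/tmp/tiny_model.pth')
model.load_checkpoint(None, '/tmp/tiny_model.pth', eval=True)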
View File

@@ -188,10 +188,10 @@ class SpeedySpeech(nn.Module):
        o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
        o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
        y_lengths = o_dr.sum(1)
-        o_de, attn= self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
+        o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
        return o_de, attn

-    def load_checkpoint(self, config, checkpoint_path, eval=False):
+    def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.load_state_dict(state['model'])
        if eval:

View File

@@ -121,7 +121,7 @@ class TacotronAbstract(ABC, nn.Module):
    def inference(self):
        pass

-    def load_checkpoint(self, config, checkpoint_path, eval=False):
+    def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.load_state_dict(state['model'])
        self.decoder.set_r(state['r'])

View File

@@ -50,7 +50,7 @@ def plot_spectrogram(spectrogram,
    spectrogram_ = spectrogram_.astype(
        np.float32) if spectrogram_.dtype == np.float16 else spectrogram_
    if ap is not None:
-        spectrogram_ = ap._denormalize(spectrogram_) # pylint: disable=protected-access
+        spectrogram_ = ap.denormalize(spectrogram_) # pylint: disable=protected-access
    fig = plt.figure(figsize=fig_size)
    plt.imshow(spectrogram_, aspect="auto", origin="lower")
    plt.colorbar()

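The plot_spectrogram change swaps the private ap._denormalize call for the now-public ap.denormalize (the protected-access disable is kept verbatim from the diff, even though it may no longer be needed). Below is a standalone sketch of the plotting path itself, with a random array standing in for an already-denormalized spectrogram so no AudioProcessor is required; the figure size is illustrative.

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

spectrogram = np.random.rand(80, 400).astype(np.float32)  # stand-in [n_mels, frames]
fig = plt.figure(figsize=(16, 10))
plt.imshow(spectrogram, aspect="auto", origin="lower")
plt.colorbar()
plt.tight_layout()
fig.savefig("spectrogram.png")
plt.close(fig)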
View File

@@ -17,7 +17,7 @@ from TTS.tts.utils.text import make_symbols, phonemes, symbols

class Synthesizer(object):

-    def __init__(self, tts_checkpoint, tts_config, vocoder_checkpoint, vocoder_config, use_cuda):
+    def __init__(self, tts_checkpoint, tts_config, vocoder_checkpoint=None, vocoder_config=None, use_cuda=False):
        """Encapsulation of tts and vocoder models for inference.

        TODO: handle multi-speaker and GST inference.
@@ -25,9 +25,9 @@ class Synthesizer(object):
        Args:
            tts_checkpoint (str): path to the tts model file.
            tts_config (str): path to the tts config file.
-            vocoder_checkpoint (str): path to the vocoder model file.
-            vocoder_config (str): path to the vocoder config file.
-            use_cuda (bool): enable/disable cuda.
+            vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None.
+            vocoder_config (str, optional): path to the vocoder config file. Defaults to None.
+            use_cuda (bool, optional): enable/disable cuda. Defaults to False.
        """
        self.tts_checkpoint = tts_checkpoint
        self.tts_config = tts_config
@@ -38,6 +38,7 @@ class Synthesizer(object):
        self.vocoder_model = None
        self.num_speakers = 0
        self.tts_speakers = None
+        self.speaker_embedding_dim = None
        self.seg = self.get_segmenter("en")
        self.use_cuda = use_cuda
        if self.use_cuda:
@@ -116,7 +117,7 @@
        print(sens)
        speaker_embedding = self.init_speaker(speaker_idx)
-        use_gl = not hasattr(self, 'vocoder_model')
+        use_gl = self.vocoder_model is None

        for sen in sens:
            # synthesize voice
@@ -134,17 +135,17 @@
                speaker_embedding=speaker_embedding)
            if not use_gl:
                # denormalize tts output based on tts audio config
-                mel_postnet_spec = self.ap._denormalize(mel_postnet_spec.T).T
+                mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T
                device_type = "cuda" if self.use_cuda else "cpu"
                # renormalize spectrogram based on vocoder config
-                vocoder_input = self.vocoder_ap._normalize(mel_postnet_spec.T)
+                vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
                # compute scale factor for possible sample rate mismatch
                scale_factor = [1, self.vocoder_config['audio']['sample_rate'] / self.ap.sample_rate]
                if scale_factor[1] != 1:
                    print(" > interpolating tts model output.")
                    vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
                else:
-                    vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)
+                    vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
            # run vocoder model
            # [1, T, C]
            waveform = self.vocoder_model.inference(vocoder_input.to(device_type))

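With the new defaults in Synthesizer.__init__ and the `use_gl = self.vocoder_model is None` check, the vocoder becomes optional: passing only a TTS checkpoint and config falls back to Griffin-Lim. A hedged usage sketch follows; the paths are placeholders, and save_wav is assumed from the surrounding class rather than shown in this diff.

from TTS.utils.synthesizer import Synthesizer

# TTS model only: vocoder_checkpoint/vocoder_config default to None, so Griffin-Lim is used
synthesizer = Synthesizer("output/tts_model.pth.tar", "output/tts_config.json")
wav = synthesizer.tts("The vocoder is optional now.")
synthesizer.save_wav(wav, "sample_gl.wav")

# TTS model plus a neural vocoder, optionally on GPU
synthesizer = Synthesizer(
    "output/tts_model.pth.tar",
    "output/tts_config.json",
    vocoder_checkpoint="output/vocoder_model.pth.tar",
    vocoder_config="output/vocoder_config.json",
    use_cuda=True,
)
wav = synthesizer.tts("This sentence goes through the vocoder.")
synthesizer.save_wav(wav, "sample_vocoder.wav")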
View File

@@ -96,7 +96,7 @@ class MelganGenerator(nn.Module):
            except ValueError:
                layer.remove_weight_norm()

-    def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument
+    def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.load_state_dict(state['model'])
        if eval:

View File

@@ -158,7 +158,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
        return self._get_receptive_field_size(self.layers, self.stacks,
                                              self.kernel_size)

-    def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument
+    def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.load_state_dict(state['model'])
        if eval:

View File

@@ -177,7 +177,7 @@ class Wavegrad(nn.Module):
            self.y_conv = weight_norm(self.y_conv)

-    def load_checkpoint(self, config, checkpoint_path, eval=False):
+    def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.load_state_dict(state['model'])
        if eval:

View File

@@ -500,7 +500,7 @@ class WaveRNN(nn.Module):
        return unfolded

-    def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument
+    def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.load_state_dict(state['model'])
        if eval:

View File

@@ -20,7 +20,7 @@ def interpolate_vocoder_input(scale_factor, spec):
        torch.tensor: interpolated spectrogram.
    """
    print(" > before interpolation :", spec.shape)
-    spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0)
+    spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable
    spec = torch.nn.functional.interpolate(spec,
                                           scale_factor=scale_factor,
                                           recompute_scale_factor=True,

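interpolate_vocoder_input resizes the TTS mel along the time axis when the TTS and vocoder sample rates differ; the not-callable disable works around a well-known pylint false positive on torch.tensor. A hedged usage sketch follows; the sample rates and mel dimensions are made up, and the printed shape depends on rounding.

import numpy as np
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input

tts_sample_rate = 22050        # e.g. TTS model trained at 22.05 kHz
vocoder_sample_rate = 24000    # e.g. vocoder trained at 24 kHz
scale_factor = [1, vocoder_sample_rate / tts_sample_rate]

mel = np.random.rand(80, 200).astype(np.float32)  # random stand-in for a normalized [C, T] mel
vocoder_input = interpolate_vocoder_input(scale_factor, mel)
print(vocoder_input.shape)  # time axis stretched by roughly 24000/22050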
View File

@@ -2,7 +2,7 @@ import os
import unittest

from tests import get_tests_input_path, get_tests_output_path
-from TTS.server.synthesizer import Synthesizer
+from TTS.utils.synthesizer import Synthesizer
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_checkpoint
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
@@ -29,7 +29,7 @@ class DemoServerTest(unittest.TestCase):
        tts_root_path = get_tests_output_path()
        config['tts_checkpoint'] = os.path.join(tts_root_path, config['tts_checkpoint'])
        config['tts_config'] = os.path.join(tts_root_path, config['tts_config'])
-        synthesizer = Synthesizer(config)
+        synthesizer = Synthesizer(config['tts_checkpoint'], config['tts_config'], None, None)
        synthesizer.tts("Better this test works!!")

    def test_split_into_sentences(self):