linter fixes and test fixes

This commit is contained in:
Eren Gölge 2021-01-22 02:32:35 +01:00
parent 32d21545ac
commit c990b3a59c
12 changed files with 26 additions and 37 deletions

View File

@ -2,26 +2,15 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import argparse import argparse
import json
import os import os
import sys import sys
import string import string
import time
from argparse import RawTextHelpFormatter from argparse import RawTextHelpFormatter
# pylint: disable=redefined-outer-name, unused-argument # pylint: disable=redefined-outer-name, unused-argument
from pathlib import Path from pathlib import Path
import numpy as np
import torch
from TTS.tts.utils.generic_utils import is_tacotron, setup_model
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.io import load_checkpoint
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
from TTS.utils.manage import ModelManager from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer from TTS.utils.synthesizer import Synthesizer
from TTS.vocoder.utils.generic_utils import setup_generator, interpolate_vocoder_input
def str2bool(v): def str2bool(v):
@ -29,17 +18,16 @@ def str2bool(v):
return v return v
if v.lower() in ('yes', 'true', 't', 'y', '1'): if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'): if v.lower() in ('no', 'false', 'f', 'n', '0'):
return False return False
else: raise argparse.ArgumentTypeError('Boolean value expected.')
raise argparse.ArgumentTypeError('Boolean value expected.')
if __name__ == "__main__": if __name__ == "__main__":
# pylint: disable=bad-continuation
parser = argparse.ArgumentParser(description='''Synthesize speech on command line.\n\n''' parser = argparse.ArgumentParser(description='''Synthesize speech on command line.\n\n'''
'''You can either use your trained model or choose a model from the provided list.\n''' '''You can either use your trained model or choose a model from the provided list.\n'''\
''' '''
Example runs: Example runs:

View File

@ -224,7 +224,7 @@ class GlowTts(nn.Module):
def store_inverse(self): def store_inverse(self):
self.decoder.store_inverse() self.decoder.store_inverse()
def load_checkpoint(self, config, checkpoint_path, eval=False): def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu')) state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model']) self.load_state_dict(state['model'])
if eval: if eval:

View File

@ -188,10 +188,10 @@ class SpeedySpeech(nn.Module):
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1) y_lengths = o_dr.sum(1)
o_de, attn= self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g) o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
return o_de, attn return o_de, attn
def load_checkpoint(self, config, checkpoint_path, eval=False): def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu')) state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model']) self.load_state_dict(state['model'])
if eval: if eval:

View File

@ -121,7 +121,7 @@ class TacotronAbstract(ABC, nn.Module):
def inference(self): def inference(self):
pass pass
def load_checkpoint(self, config, checkpoint_path, eval=False): def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu')) state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model']) self.load_state_dict(state['model'])
self.decoder.set_r(state['r']) self.decoder.set_r(state['r'])

View File

@ -50,7 +50,7 @@ def plot_spectrogram(spectrogram,
spectrogram_ = spectrogram_.astype( spectrogram_ = spectrogram_.astype(
np.float32) if spectrogram_.dtype == np.float16 else spectrogram_ np.float32) if spectrogram_.dtype == np.float16 else spectrogram_
if ap is not None: if ap is not None:
spectrogram_ = ap._denormalize(spectrogram_) # pylint: disable=protected-access spectrogram_ = ap.denormalize(spectrogram_) # pylint: disable=protected-access
fig = plt.figure(figsize=fig_size) fig = plt.figure(figsize=fig_size)
plt.imshow(spectrogram_, aspect="auto", origin="lower") plt.imshow(spectrogram_, aspect="auto", origin="lower")
plt.colorbar() plt.colorbar()

View File

@ -17,7 +17,7 @@ from TTS.tts.utils.text import make_symbols, phonemes, symbols
class Synthesizer(object): class Synthesizer(object):
def __init__(self, tts_checkpoint, tts_config, vocoder_checkpoint, vocoder_config, use_cuda): def __init__(self, tts_checkpoint, tts_config, vocoder_checkpoint=None, vocoder_config=None, use_cuda=False):
"""Encapsulation of tts and vocoder models for inference. """Encapsulation of tts and vocoder models for inference.
TODO: handle multi-speaker and GST inference. TODO: handle multi-speaker and GST inference.
@ -25,9 +25,9 @@ class Synthesizer(object):
Args: Args:
tts_checkpoint (str): path to the tts model file. tts_checkpoint (str): path to the tts model file.
tts_config (str): path to the tts config file. tts_config (str): path to the tts config file.
vocoder_checkpoint (str): path to the vocoder model file. vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None.
vocoder_config (str): path to the vocoder config file. vocoder_config (str, optional): path to the vocoder config file. Defaults to None.
use_cuda (bool): enable/disable cuda. use_cuda (bool, optional): enable/disable cuda. Defaults to False.
""" """
self.tts_checkpoint = tts_checkpoint self.tts_checkpoint = tts_checkpoint
self.tts_config = tts_config self.tts_config = tts_config
@ -38,6 +38,7 @@ class Synthesizer(object):
self.vocoder_model = None self.vocoder_model = None
self.num_speakers = 0 self.num_speakers = 0
self.tts_speakers = None self.tts_speakers = None
self.speaker_embedding_dim = None
self.seg = self.get_segmenter("en") self.seg = self.get_segmenter("en")
self.use_cuda = use_cuda self.use_cuda = use_cuda
if self.use_cuda: if self.use_cuda:
@ -116,7 +117,7 @@ class Synthesizer(object):
print(sens) print(sens)
speaker_embedding = self.init_speaker(speaker_idx) speaker_embedding = self.init_speaker(speaker_idx)
use_gl = not hasattr(self, 'vocoder_model') use_gl = self.vocoder_model is None
for sen in sens: for sen in sens:
# synthesize voice # synthesize voice
@ -134,17 +135,17 @@ class Synthesizer(object):
speaker_embedding=speaker_embedding) speaker_embedding=speaker_embedding)
if not use_gl: if not use_gl:
# denormalize tts output based on tts audio config # denormalize tts output based on tts audio config
mel_postnet_spec = self.ap._denormalize(mel_postnet_spec.T).T mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T
device_type = "cuda" if self.use_cuda else "cpu" device_type = "cuda" if self.use_cuda else "cpu"
# renormalize spectrogram based on vocoder config # renormalize spectrogram based on vocoder config
vocoder_input = self.vocoder_ap._normalize(mel_postnet_spec.T) vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
# compute scale factor for possible sample rate mismatch # compute scale factor for possible sample rate mismatch
scale_factor = [1, self.vocoder_config['audio']['sample_rate'] / self.ap.sample_rate] scale_factor = [1, self.vocoder_config['audio']['sample_rate'] / self.ap.sample_rate]
if scale_factor[1] != 1: if scale_factor[1] != 1:
print(" > interpolating tts model output.") print(" > interpolating tts model output.")
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
else: else:
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
# run vocoder model # run vocoder model
# [1, T, C] # [1, T, C]
waveform = self.vocoder_model.inference(vocoder_input.to(device_type)) waveform = self.vocoder_model.inference(vocoder_input.to(device_type))

View File

@ -96,7 +96,7 @@ class MelganGenerator(nn.Module):
except ValueError: except ValueError:
layer.remove_weight_norm() layer.remove_weight_norm()
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu')) state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model']) self.load_state_dict(state['model'])
if eval: if eval:

View File

@ -158,7 +158,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
return self._get_receptive_field_size(self.layers, self.stacks, return self._get_receptive_field_size(self.layers, self.stacks,
self.kernel_size) self.kernel_size)
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu')) state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model']) self.load_state_dict(state['model'])
if eval: if eval:

View File

@ -177,7 +177,7 @@ class Wavegrad(nn.Module):
self.y_conv = weight_norm(self.y_conv) self.y_conv = weight_norm(self.y_conv)
def load_checkpoint(self, config, checkpoint_path, eval=False): def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu')) state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model']) self.load_state_dict(state['model'])
if eval: if eval:

View File

@ -500,7 +500,7 @@ class WaveRNN(nn.Module):
return unfolded return unfolded
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu')) state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model']) self.load_state_dict(state['model'])
if eval: if eval:

View File

@ -20,7 +20,7 @@ def interpolate_vocoder_input(scale_factor, spec):
torch.tensor: interpolated spectrogram. torch.tensor: interpolated spectrogram.
""" """
print(" > before interpolation :", spec.shape) print(" > before interpolation :", spec.shape)
spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable
spec = torch.nn.functional.interpolate(spec, spec = torch.nn.functional.interpolate(spec,
scale_factor=scale_factor, scale_factor=scale_factor,
recompute_scale_factor=True, recompute_scale_factor=True,

View File

@ -2,7 +2,7 @@ import os
import unittest import unittest
from tests import get_tests_input_path, get_tests_output_path from tests import get_tests_input_path, get_tests_output_path
from TTS.server.synthesizer import Synthesizer from TTS.utils.synthesizer import Synthesizer
from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_checkpoint from TTS.tts.utils.io import save_checkpoint
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
@ -29,7 +29,7 @@ class DemoServerTest(unittest.TestCase):
tts_root_path = get_tests_output_path() tts_root_path = get_tests_output_path()
config['tts_checkpoint'] = os.path.join(tts_root_path, config['tts_checkpoint']) config['tts_checkpoint'] = os.path.join(tts_root_path, config['tts_checkpoint'])
config['tts_config'] = os.path.join(tts_root_path, config['tts_config']) config['tts_config'] = os.path.join(tts_root_path, config['tts_config'])
synthesizer = Synthesizer(config) synthesizer = Synthesizer(config['tts_checkpoint'], config['tts_config'], None, None)
synthesizer.tts("Better this test works!!") synthesizer.tts("Better this test works!!")
def test_split_into_sentences(self): def test_split_into_sentences(self):