mirror of https://github.com/coqui-ai/TTS.git
linter fixes and test fixes
This commit is contained in:
parent
32d21545ac
commit
c990b3a59c
|
@ -2,26 +2,15 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import string
|
||||
import time
|
||||
from argparse import RawTextHelpFormatter
|
||||
# pylint: disable=redefined-outer-name, unused-argument
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from TTS.tts.utils.generic_utils import is_tacotron, setup_model
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||
from TTS.tts.utils.io import load_checkpoint
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.io import load_config
|
||||
from TTS.utils.manage import ModelManager
|
||||
from TTS.utils.synthesizer import Synthesizer
|
||||
from TTS.vocoder.utils.generic_utils import setup_generator, interpolate_vocoder_input
|
||||
|
||||
|
||||
def str2bool(v):
|
||||
|
@ -29,17 +18,16 @@ def str2bool(v):
|
|||
return v
|
||||
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
||||
return True
|
||||
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
||||
if v.lower() in ('no', 'false', 'f', 'n', '0'):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError('Boolean value expected.')
|
||||
raise argparse.ArgumentTypeError('Boolean value expected.')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# pylint: disable=bad-continuation
|
||||
parser = argparse.ArgumentParser(description='''Synthesize speech on command line.\n\n'''
|
||||
|
||||
'''You can either use your trained model or choose a model from the provided list.\n'''
|
||||
'''You can either use your trained model or choose a model from the provided list.\n'''\
|
||||
|
||||
'''
|
||||
Example runs:
|
||||
|
|
|
@ -224,7 +224,7 @@ class GlowTts(nn.Module):
|
|||
def store_inverse(self):
|
||||
self.decoder.store_inverse()
|
||||
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False):
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
self.load_state_dict(state['model'])
|
||||
if eval:
|
||||
|
|
|
@ -188,10 +188,10 @@ class SpeedySpeech(nn.Module):
|
|||
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
|
||||
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
||||
y_lengths = o_dr.sum(1)
|
||||
o_de, attn= self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
|
||||
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
|
||||
return o_de, attn
|
||||
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False):
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
self.load_state_dict(state['model'])
|
||||
if eval:
|
||||
|
|
|
@ -121,7 +121,7 @@ class TacotronAbstract(ABC, nn.Module):
|
|||
def inference(self):
|
||||
pass
|
||||
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False):
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
self.load_state_dict(state['model'])
|
||||
self.decoder.set_r(state['r'])
|
||||
|
|
|
@ -50,7 +50,7 @@ def plot_spectrogram(spectrogram,
|
|||
spectrogram_ = spectrogram_.astype(
|
||||
np.float32) if spectrogram_.dtype == np.float16 else spectrogram_
|
||||
if ap is not None:
|
||||
spectrogram_ = ap._denormalize(spectrogram_) # pylint: disable=protected-access
|
||||
spectrogram_ = ap.denormalize(spectrogram_) # pylint: disable=protected-access
|
||||
fig = plt.figure(figsize=fig_size)
|
||||
plt.imshow(spectrogram_, aspect="auto", origin="lower")
|
||||
plt.colorbar()
|
||||
|
|
|
@ -17,7 +17,7 @@ from TTS.tts.utils.text import make_symbols, phonemes, symbols
|
|||
|
||||
|
||||
class Synthesizer(object):
|
||||
def __init__(self, tts_checkpoint, tts_config, vocoder_checkpoint, vocoder_config, use_cuda):
|
||||
def __init__(self, tts_checkpoint, tts_config, vocoder_checkpoint=None, vocoder_config=None, use_cuda=False):
|
||||
"""Encapsulation of tts and vocoder models for inference.
|
||||
|
||||
TODO: handle multi-speaker and GST inference.
|
||||
|
@ -25,9 +25,9 @@ class Synthesizer(object):
|
|||
Args:
|
||||
tts_checkpoint (str): path to the tts model file.
|
||||
tts_config (str): path to the tts config file.
|
||||
vocoder_checkpoint (str): path to the vocoder model file.
|
||||
vocoder_config (str): path to the vocoder config file.
|
||||
use_cuda (bool): enable/disable cuda.
|
||||
vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None.
|
||||
vocoder_config (str, optional): path to the vocoder config file. Defaults to None.
|
||||
use_cuda (bool, optional): enable/disable cuda. Defaults to False.
|
||||
"""
|
||||
self.tts_checkpoint = tts_checkpoint
|
||||
self.tts_config = tts_config
|
||||
|
@ -38,6 +38,7 @@ class Synthesizer(object):
|
|||
self.vocoder_model = None
|
||||
self.num_speakers = 0
|
||||
self.tts_speakers = None
|
||||
self.speaker_embedding_dim = None
|
||||
self.seg = self.get_segmenter("en")
|
||||
self.use_cuda = use_cuda
|
||||
if self.use_cuda:
|
||||
|
@ -116,7 +117,7 @@ class Synthesizer(object):
|
|||
print(sens)
|
||||
|
||||
speaker_embedding = self.init_speaker(speaker_idx)
|
||||
use_gl = not hasattr(self, 'vocoder_model')
|
||||
use_gl = self.vocoder_model is None
|
||||
|
||||
for sen in sens:
|
||||
# synthesize voice
|
||||
|
@ -134,17 +135,17 @@ class Synthesizer(object):
|
|||
speaker_embedding=speaker_embedding)
|
||||
if not use_gl:
|
||||
# denormalize tts output based on tts audio config
|
||||
mel_postnet_spec = self.ap._denormalize(mel_postnet_spec.T).T
|
||||
mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T
|
||||
device_type = "cuda" if self.use_cuda else "cpu"
|
||||
# renormalize spectrogram based on vocoder config
|
||||
vocoder_input = self.vocoder_ap._normalize(mel_postnet_spec.T)
|
||||
vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
|
||||
# compute scale factor for possible sample rate mismatch
|
||||
scale_factor = [1, self.vocoder_config['audio']['sample_rate'] / self.ap.sample_rate]
|
||||
scale_factor = [1, self.vocoder_config['audio']['sample_rate'] / self.ap.sample_rate]
|
||||
if scale_factor[1] != 1:
|
||||
print(" > interpolating tts model output.")
|
||||
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
|
||||
else:
|
||||
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)
|
||||
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
|
||||
# run vocoder model
|
||||
# [1, T, C]
|
||||
waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
|
||||
|
|
|
@ -96,7 +96,7 @@ class MelganGenerator(nn.Module):
|
|||
except ValueError:
|
||||
layer.remove_weight_norm()
|
||||
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
self.load_state_dict(state['model'])
|
||||
if eval:
|
||||
|
|
|
@ -158,7 +158,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
|
|||
return self._get_receptive_field_size(self.layers, self.stacks,
|
||||
self.kernel_size)
|
||||
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
self.load_state_dict(state['model'])
|
||||
if eval:
|
||||
|
|
|
@ -177,7 +177,7 @@ class Wavegrad(nn.Module):
|
|||
self.y_conv = weight_norm(self.y_conv)
|
||||
|
||||
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False):
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
self.load_state_dict(state['model'])
|
||||
if eval:
|
||||
|
|
|
@ -500,7 +500,7 @@ class WaveRNN(nn.Module):
|
|||
|
||||
return unfolded
|
||||
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument
|
||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||
self.load_state_dict(state['model'])
|
||||
if eval:
|
||||
|
|
|
@ -20,7 +20,7 @@ def interpolate_vocoder_input(scale_factor, spec):
|
|||
torch.tensor: interpolated spectrogram.
|
||||
"""
|
||||
print(" > before interpolation :", spec.shape)
|
||||
spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0)
|
||||
spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable
|
||||
spec = torch.nn.functional.interpolate(spec,
|
||||
scale_factor=scale_factor,
|
||||
recompute_scale_factor=True,
|
||||
|
|
|
@ -2,7 +2,7 @@ import os
|
|||
import unittest
|
||||
|
||||
from tests import get_tests_input_path, get_tests_output_path
|
||||
from TTS.server.synthesizer import Synthesizer
|
||||
from TTS.utils.synthesizer import Synthesizer
|
||||
from TTS.tts.utils.generic_utils import setup_model
|
||||
from TTS.tts.utils.io import save_checkpoint
|
||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||
|
@ -29,7 +29,7 @@ class DemoServerTest(unittest.TestCase):
|
|||
tts_root_path = get_tests_output_path()
|
||||
config['tts_checkpoint'] = os.path.join(tts_root_path, config['tts_checkpoint'])
|
||||
config['tts_config'] = os.path.join(tts_root_path, config['tts_config'])
|
||||
synthesizer = Synthesizer(config)
|
||||
synthesizer = Synthesizer(config['tts_checkpoint'], config['tts_config'], None, None)
|
||||
synthesizer.tts("Better this test works!!")
|
||||
|
||||
def test_split_into_sentences(self):
|
||||
|
|
Loading…
Reference in New Issue