mirror of https://github.com/coqui-ai/TTS.git
linter fixes and test fixes
This commit is contained in:
parent
32d21545ac
commit
c990b3a59c
|
@ -2,26 +2,15 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import string
|
import string
|
||||||
import time
|
|
||||||
from argparse import RawTextHelpFormatter
|
from argparse import RawTextHelpFormatter
|
||||||
# pylint: disable=redefined-outer-name, unused-argument
|
# pylint: disable=redefined-outer-name, unused-argument
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from TTS.tts.utils.generic_utils import is_tacotron, setup_model
|
|
||||||
from TTS.tts.utils.synthesis import synthesis
|
|
||||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
|
||||||
from TTS.tts.utils.io import load_checkpoint
|
|
||||||
from TTS.utils.audio import AudioProcessor
|
|
||||||
from TTS.utils.io import load_config
|
|
||||||
from TTS.utils.manage import ModelManager
|
from TTS.utils.manage import ModelManager
|
||||||
from TTS.utils.synthesizer import Synthesizer
|
from TTS.utils.synthesizer import Synthesizer
|
||||||
from TTS.vocoder.utils.generic_utils import setup_generator, interpolate_vocoder_input
|
|
||||||
|
|
||||||
|
|
||||||
def str2bool(v):
|
def str2bool(v):
|
||||||
|
@ -29,17 +18,16 @@ def str2bool(v):
|
||||||
return v
|
return v
|
||||||
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
||||||
return True
|
return True
|
||||||
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
if v.lower() in ('no', 'false', 'f', 'n', '0'):
|
||||||
return False
|
return False
|
||||||
else:
|
raise argparse.ArgumentTypeError('Boolean value expected.')
|
||||||
raise argparse.ArgumentTypeError('Boolean value expected.')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# pylint: disable=bad-continuation
|
||||||
parser = argparse.ArgumentParser(description='''Synthesize speech on command line.\n\n'''
|
parser = argparse.ArgumentParser(description='''Synthesize speech on command line.\n\n'''
|
||||||
|
|
||||||
'''You can either use your trained model or choose a model from the provided list.\n'''
|
'''You can either use your trained model or choose a model from the provided list.\n'''\
|
||||||
|
|
||||||
'''
|
'''
|
||||||
Example runs:
|
Example runs:
|
||||||
|
|
|
@ -224,7 +224,7 @@ class GlowTts(nn.Module):
|
||||||
def store_inverse(self):
|
def store_inverse(self):
|
||||||
self.decoder.store_inverse()
|
self.decoder.store_inverse()
|
||||||
|
|
||||||
def load_checkpoint(self, config, checkpoint_path, eval=False):
|
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
self.load_state_dict(state['model'])
|
self.load_state_dict(state['model'])
|
||||||
if eval:
|
if eval:
|
||||||
|
|
|
@ -188,10 +188,10 @@ class SpeedySpeech(nn.Module):
|
||||||
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
|
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
|
||||||
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
||||||
y_lengths = o_dr.sum(1)
|
y_lengths = o_dr.sum(1)
|
||||||
o_de, attn= self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
|
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
|
||||||
return o_de, attn
|
return o_de, attn
|
||||||
|
|
||||||
def load_checkpoint(self, config, checkpoint_path, eval=False):
|
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
self.load_state_dict(state['model'])
|
self.load_state_dict(state['model'])
|
||||||
if eval:
|
if eval:
|
||||||
|
|
|
@ -121,7 +121,7 @@ class TacotronAbstract(ABC, nn.Module):
|
||||||
def inference(self):
|
def inference(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def load_checkpoint(self, config, checkpoint_path, eval=False):
|
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
self.load_state_dict(state['model'])
|
self.load_state_dict(state['model'])
|
||||||
self.decoder.set_r(state['r'])
|
self.decoder.set_r(state['r'])
|
||||||
|
|
|
@ -50,7 +50,7 @@ def plot_spectrogram(spectrogram,
|
||||||
spectrogram_ = spectrogram_.astype(
|
spectrogram_ = spectrogram_.astype(
|
||||||
np.float32) if spectrogram_.dtype == np.float16 else spectrogram_
|
np.float32) if spectrogram_.dtype == np.float16 else spectrogram_
|
||||||
if ap is not None:
|
if ap is not None:
|
||||||
spectrogram_ = ap._denormalize(spectrogram_) # pylint: disable=protected-access
|
spectrogram_ = ap.denormalize(spectrogram_) # pylint: disable=protected-access
|
||||||
fig = plt.figure(figsize=fig_size)
|
fig = plt.figure(figsize=fig_size)
|
||||||
plt.imshow(spectrogram_, aspect="auto", origin="lower")
|
plt.imshow(spectrogram_, aspect="auto", origin="lower")
|
||||||
plt.colorbar()
|
plt.colorbar()
|
||||||
|
|
|
@ -17,7 +17,7 @@ from TTS.tts.utils.text import make_symbols, phonemes, symbols
|
||||||
|
|
||||||
|
|
||||||
class Synthesizer(object):
|
class Synthesizer(object):
|
||||||
def __init__(self, tts_checkpoint, tts_config, vocoder_checkpoint, vocoder_config, use_cuda):
|
def __init__(self, tts_checkpoint, tts_config, vocoder_checkpoint=None, vocoder_config=None, use_cuda=False):
|
||||||
"""Encapsulation of tts and vocoder models for inference.
|
"""Encapsulation of tts and vocoder models for inference.
|
||||||
|
|
||||||
TODO: handle multi-speaker and GST inference.
|
TODO: handle multi-speaker and GST inference.
|
||||||
|
@ -25,9 +25,9 @@ class Synthesizer(object):
|
||||||
Args:
|
Args:
|
||||||
tts_checkpoint (str): path to the tts model file.
|
tts_checkpoint (str): path to the tts model file.
|
||||||
tts_config (str): path to the tts config file.
|
tts_config (str): path to the tts config file.
|
||||||
vocoder_checkpoint (str): path to the vocoder model file.
|
vocoder_checkpoint (str, optional): path to the vocoder model file. Defaults to None.
|
||||||
vocoder_config (str): path to the vocoder config file.
|
vocoder_config (str, optional): path to the vocoder config file. Defaults to None.
|
||||||
use_cuda (bool): enable/disable cuda.
|
use_cuda (bool, optional): enable/disable cuda. Defaults to False.
|
||||||
"""
|
"""
|
||||||
self.tts_checkpoint = tts_checkpoint
|
self.tts_checkpoint = tts_checkpoint
|
||||||
self.tts_config = tts_config
|
self.tts_config = tts_config
|
||||||
|
@ -38,6 +38,7 @@ class Synthesizer(object):
|
||||||
self.vocoder_model = None
|
self.vocoder_model = None
|
||||||
self.num_speakers = 0
|
self.num_speakers = 0
|
||||||
self.tts_speakers = None
|
self.tts_speakers = None
|
||||||
|
self.speaker_embedding_dim = None
|
||||||
self.seg = self.get_segmenter("en")
|
self.seg = self.get_segmenter("en")
|
||||||
self.use_cuda = use_cuda
|
self.use_cuda = use_cuda
|
||||||
if self.use_cuda:
|
if self.use_cuda:
|
||||||
|
@ -116,7 +117,7 @@ class Synthesizer(object):
|
||||||
print(sens)
|
print(sens)
|
||||||
|
|
||||||
speaker_embedding = self.init_speaker(speaker_idx)
|
speaker_embedding = self.init_speaker(speaker_idx)
|
||||||
use_gl = not hasattr(self, 'vocoder_model')
|
use_gl = self.vocoder_model is None
|
||||||
|
|
||||||
for sen in sens:
|
for sen in sens:
|
||||||
# synthesize voice
|
# synthesize voice
|
||||||
|
@ -134,17 +135,17 @@ class Synthesizer(object):
|
||||||
speaker_embedding=speaker_embedding)
|
speaker_embedding=speaker_embedding)
|
||||||
if not use_gl:
|
if not use_gl:
|
||||||
# denormalize tts output based on tts audio config
|
# denormalize tts output based on tts audio config
|
||||||
mel_postnet_spec = self.ap._denormalize(mel_postnet_spec.T).T
|
mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T
|
||||||
device_type = "cuda" if self.use_cuda else "cpu"
|
device_type = "cuda" if self.use_cuda else "cpu"
|
||||||
# renormalize spectrogram based on vocoder config
|
# renormalize spectrogram based on vocoder config
|
||||||
vocoder_input = self.vocoder_ap._normalize(mel_postnet_spec.T)
|
vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
|
||||||
# compute scale factor for possible sample rate mismatch
|
# compute scale factor for possible sample rate mismatch
|
||||||
scale_factor = [1, self.vocoder_config['audio']['sample_rate'] / self.ap.sample_rate]
|
scale_factor = [1, self.vocoder_config['audio']['sample_rate'] / self.ap.sample_rate]
|
||||||
if scale_factor[1] != 1:
|
if scale_factor[1] != 1:
|
||||||
print(" > interpolating tts model output.")
|
print(" > interpolating tts model output.")
|
||||||
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
|
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
|
||||||
else:
|
else:
|
||||||
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0)
|
vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable
|
||||||
# run vocoder model
|
# run vocoder model
|
||||||
# [1, T, C]
|
# [1, T, C]
|
||||||
waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
|
waveform = self.vocoder_model.inference(vocoder_input.to(device_type))
|
||||||
|
|
|
@ -96,7 +96,7 @@ class MelganGenerator(nn.Module):
|
||||||
except ValueError:
|
except ValueError:
|
||||||
layer.remove_weight_norm()
|
layer.remove_weight_norm()
|
||||||
|
|
||||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument
|
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
self.load_state_dict(state['model'])
|
self.load_state_dict(state['model'])
|
||||||
if eval:
|
if eval:
|
||||||
|
|
|
@ -158,7 +158,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
|
||||||
return self._get_receptive_field_size(self.layers, self.stacks,
|
return self._get_receptive_field_size(self.layers, self.stacks,
|
||||||
self.kernel_size)
|
self.kernel_size)
|
||||||
|
|
||||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument
|
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
self.load_state_dict(state['model'])
|
self.load_state_dict(state['model'])
|
||||||
if eval:
|
if eval:
|
||||||
|
|
|
@ -177,7 +177,7 @@ class Wavegrad(nn.Module):
|
||||||
self.y_conv = weight_norm(self.y_conv)
|
self.y_conv = weight_norm(self.y_conv)
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(self, config, checkpoint_path, eval=False):
|
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
self.load_state_dict(state['model'])
|
self.load_state_dict(state['model'])
|
||||||
if eval:
|
if eval:
|
||||||
|
|
|
@ -500,7 +500,7 @@ class WaveRNN(nn.Module):
|
||||||
|
|
||||||
return unfolded
|
return unfolded
|
||||||
|
|
||||||
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument
|
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
|
||||||
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
self.load_state_dict(state['model'])
|
self.load_state_dict(state['model'])
|
||||||
if eval:
|
if eval:
|
||||||
|
|
|
@ -20,7 +20,7 @@ def interpolate_vocoder_input(scale_factor, spec):
|
||||||
torch.tensor: interpolated spectrogram.
|
torch.tensor: interpolated spectrogram.
|
||||||
"""
|
"""
|
||||||
print(" > before interpolation :", spec.shape)
|
print(" > before interpolation :", spec.shape)
|
||||||
spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0)
|
spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable
|
||||||
spec = torch.nn.functional.interpolate(spec,
|
spec = torch.nn.functional.interpolate(spec,
|
||||||
scale_factor=scale_factor,
|
scale_factor=scale_factor,
|
||||||
recompute_scale_factor=True,
|
recompute_scale_factor=True,
|
||||||
|
|
|
@ -2,7 +2,7 @@ import os
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from tests import get_tests_input_path, get_tests_output_path
|
from tests import get_tests_input_path, get_tests_output_path
|
||||||
from TTS.server.synthesizer import Synthesizer
|
from TTS.utils.synthesizer import Synthesizer
|
||||||
from TTS.tts.utils.generic_utils import setup_model
|
from TTS.tts.utils.generic_utils import setup_model
|
||||||
from TTS.tts.utils.io import save_checkpoint
|
from TTS.tts.utils.io import save_checkpoint
|
||||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||||
|
@ -29,7 +29,7 @@ class DemoServerTest(unittest.TestCase):
|
||||||
tts_root_path = get_tests_output_path()
|
tts_root_path = get_tests_output_path()
|
||||||
config['tts_checkpoint'] = os.path.join(tts_root_path, config['tts_checkpoint'])
|
config['tts_checkpoint'] = os.path.join(tts_root_path, config['tts_checkpoint'])
|
||||||
config['tts_config'] = os.path.join(tts_root_path, config['tts_config'])
|
config['tts_config'] = os.path.join(tts_root_path, config['tts_config'])
|
||||||
synthesizer = Synthesizer(config)
|
synthesizer = Synthesizer(config['tts_checkpoint'], config['tts_config'], None, None)
|
||||||
synthesizer.tts("Better this test works!!")
|
synthesizer.tts("Better this test works!!")
|
||||||
|
|
||||||
def test_split_into_sentences(self):
|
def test_split_into_sentences(self):
|
||||||
|
|
Loading…
Reference in New Issue