mirror of https://github.com/coqui-ai/TTS.git
update synthesize.py
This commit is contained in:
parent
69f525f17d
commit
7080b5fb34
|
@ -1,53 +1,41 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
# pylint: disable=redefined-outer-name, unused-argument
|
# pylint: disable=redefined-outer-name, unused-argument
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
import argparse
|
|
||||||
import torch
|
|
||||||
import json
|
|
||||||
import string
|
import string
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from TTS.tts.utils.synthesis import synthesis
|
|
||||||
from TTS.tts.utils.generic_utils import setup_model
|
from TTS.tts.utils.generic_utils import setup_model
|
||||||
from TTS.tts.utils.io import load_config
|
from TTS.tts.utils.synthesis import synthesis
|
||||||
from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes
|
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||||
from TTS.tts.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
from TTS.utils.io import load_config
|
||||||
|
from TTS.vocoder.utils.generic_utils import setup_generator
|
||||||
|
|
||||||
|
|
||||||
def tts(model,
|
def tts(model, vocoder_model, text, CONFIG, use_cuda, ap, use_gl, speaker_id):
|
||||||
vocoder_model,
|
|
||||||
C,
|
|
||||||
VC,
|
|
||||||
text,
|
|
||||||
ap,
|
|
||||||
ap_vocoder,
|
|
||||||
use_cuda,
|
|
||||||
batched_vocoder,
|
|
||||||
speaker_id=None,
|
|
||||||
figures=False):
|
|
||||||
t_1 = time.time()
|
t_1 = time.time()
|
||||||
use_vocoder_model = vocoder_model is not None
|
waveform, _, _, mel_postnet_spec, stop_tokens, _ = synthesis(model, text, CONFIG, use_cuda, ap, speaker_id, None, False, CONFIG.enable_eos_bos_chars, use_gl)
|
||||||
waveform, alignment, _, postnet_output, stop_tokens, _ = synthesis(
|
if CONFIG.model == "Tacotron" and not use_gl:
|
||||||
model, text, C, use_cuda, ap, speaker_id, style_wav=False,
|
mel_postnet_spec = ap.out_linear_to_mel(mel_postnet_spec.T).T
|
||||||
truncated=False, enable_eos_bos_chars=C.enable_eos_bos_chars,
|
if not use_gl:
|
||||||
use_griffin_lim=(not use_vocoder_model), do_trim_silence=True)
|
waveform = vocoder_model.inference(torch.FloatTensor(mel_postnet_spec.T).unsqueeze(0))
|
||||||
|
if use_cuda and not use_gl:
|
||||||
if C.model == "Tacotron" and use_vocoder_model:
|
waveform = waveform.cpu()
|
||||||
postnet_output = ap.out_linear_to_mel(postnet_output.T).T
|
if not use_gl:
|
||||||
# correct if there is a scale difference b/w two models
|
waveform = waveform.numpy()
|
||||||
if use_vocoder_model:
|
waveform = waveform.squeeze()
|
||||||
postnet_output = ap._denormalize(postnet_output)
|
rtf = (time.time() - t_1) / (len(waveform) / ap.sample_rate)
|
||||||
postnet_output = ap_vocoder._normalize(postnet_output)
|
tps = (time.time() - t_1) / len(waveform)
|
||||||
vocoder_input = torch.FloatTensor(postnet_output.T).unsqueeze(0)
|
print(" > Run-time: {}".format(time.time() - t_1))
|
||||||
waveform = vocoder_model.generate(
|
print(" > Real-time factor: {}".format(rtf))
|
||||||
vocoder_input.cuda() if use_cuda else vocoder_input,
|
print(" > Time per step: {}".format(tps))
|
||||||
batched=batched_vocoder,
|
return waveform
|
||||||
target=8000,
|
|
||||||
overlap=400)
|
|
||||||
print(" > Run-time: {}".format(time.time() - t_1))
|
|
||||||
return alignment, postnet_output, stop_tokens, waveform
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -100,10 +88,6 @@ if __name__ == "__main__":
|
||||||
default=None)
|
default=None)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.vocoder_path != "":
|
|
||||||
assert args.use_cuda, " [!] Enable cuda for vocoder."
|
|
||||||
from WaveRNN.models.wavernn import Model as VocoderModel
|
|
||||||
|
|
||||||
# load the config
|
# load the config
|
||||||
C = load_config(args.config_path)
|
C = load_config(args.config_path)
|
||||||
C.forward_attn_mask = True
|
C.forward_attn_mask = True
|
||||||
|
@ -125,7 +109,7 @@ if __name__ == "__main__":
|
||||||
# load the model
|
# load the model
|
||||||
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
|
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
|
||||||
model = setup_model(num_chars, num_speakers, C)
|
model = setup_model(num_chars, num_speakers, C)
|
||||||
cp = torch.load(args.model_path)
|
cp = torch.load(args.model_path, map_location=torch.device('cpu'))
|
||||||
model.load_state_dict(cp['model'])
|
model.load_state_dict(cp['model'])
|
||||||
model.eval()
|
model.eval()
|
||||||
if args.use_cuda:
|
if args.use_cuda:
|
||||||
|
@ -135,46 +119,20 @@ if __name__ == "__main__":
|
||||||
# load vocoder model
|
# load vocoder model
|
||||||
if args.vocoder_path != "":
|
if args.vocoder_path != "":
|
||||||
VC = load_config(args.vocoder_config_path)
|
VC = load_config(args.vocoder_config_path)
|
||||||
ap_vocoder = AudioProcessor(**VC.audio)
|
vocoder_model = setup_generator(VC)
|
||||||
bits = 10
|
vocoder_model.load_state_dict(torch.load(args.vocoder_path, map_location="cpu")["model"])
|
||||||
vocoder_model = VocoderModel(rnn_dims=512,
|
vocoder_model.remove_weight_norm()
|
||||||
fc_dims=512,
|
|
||||||
mode=VC.mode,
|
|
||||||
mulaw=VC.mulaw,
|
|
||||||
pad=VC.pad,
|
|
||||||
upsample_factors=VC.upsample_factors,
|
|
||||||
feat_dims=VC.audio["num_mels"],
|
|
||||||
compute_dims=128,
|
|
||||||
res_out_dims=128,
|
|
||||||
res_blocks=10,
|
|
||||||
hop_length=ap.hop_length,
|
|
||||||
sample_rate=ap.sample_rate,
|
|
||||||
use_aux_net=True,
|
|
||||||
use_upsample_net=True)
|
|
||||||
|
|
||||||
check = torch.load(args.vocoder_path)
|
|
||||||
vocoder_model.load_state_dict(check['model'])
|
|
||||||
vocoder_model.eval()
|
|
||||||
if args.use_cuda:
|
if args.use_cuda:
|
||||||
vocoder_model.cuda()
|
vocoder_model.cuda()
|
||||||
|
vocoder_model.eval()
|
||||||
else:
|
else:
|
||||||
vocoder_model = None
|
vocoder_model = None
|
||||||
VC = None
|
VC = None
|
||||||
ap_vocoder = None
|
|
||||||
|
|
||||||
# synthesize voice
|
# synthesize voice
|
||||||
|
use_griffin_lim = args.vocoder_path == ""
|
||||||
print(" > Text: {}".format(args.text))
|
print(" > Text: {}".format(args.text))
|
||||||
_, _, _, wav = tts(model,
|
wav = tts(model, vocoder_model, args.text, C, args.use_cuda, ap, use_griffin_lim, args.speaker_id)
|
||||||
vocoder_model,
|
|
||||||
C,
|
|
||||||
VC,
|
|
||||||
args.text,
|
|
||||||
ap,
|
|
||||||
ap_vocoder,
|
|
||||||
args.use_cuda,
|
|
||||||
args.batched_vocoder,
|
|
||||||
speaker_id=args.speaker_id,
|
|
||||||
figures=False)
|
|
||||||
|
|
||||||
# save the results
|
# save the results
|
||||||
file_name = args.text.replace(" ", "_")
|
file_name = args.text.replace(" ", "_")
|
||||||
|
|
|
@ -121,6 +121,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def inference(self, c):
|
def inference(self, c):
|
||||||
c = c.to(self.first_conv.weight.device)
|
c = c.to(self.first_conv.weight.device)
|
||||||
c = torch.nn.functional.pad(
|
c = torch.nn.functional.pad(
|
||||||
|
|
Loading…
Reference in New Issue