formatting

This commit is contained in:
Eren Golge 2019-09-24 17:20:01 +02:00
parent 23f6743ac9
commit 53d658fb74
1 changed file with 48 additions and 54 deletions

View File

@ -24,8 +24,9 @@ def tts(model,
figures=False): figures=False):
t_1 = time.time() t_1 = time.time()
use_vocoder_model = vocoder_model is not None use_vocoder_model = vocoder_model is not None
waveform, alignment, decoder_outputs, postnet_output, stop_tokens = synthesis( waveform, alignment, _, postnet_output, stop_tokens = synthesis(
model, text, C, use_cuda, ap, speaker_id, False, C.enable_eos_bos_chars) model, text, C, use_cuda, ap, speaker_id, False,
C.enable_eos_bos_chars)
if C.model == "Tacotron" and use_vocoder_model: if C.model == "Tacotron" and use_vocoder_model:
postnet_output = ap.out_linear_to_mel(postnet_output.T).T postnet_output = ap.out_linear_to_mel(postnet_output.T).T
# correct if there is a scale difference b/w two models # correct if there is a scale difference b/w two models
@ -45,13 +46,10 @@ def tts(model,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument('text', type=str, help='Text to generate speech.')
'text', type=str, help='Text to generate speech.') parser.add_argument('config_path',
parser.add_argument(
'config_path',
type=str, type=str,
help='Path to model config file.' help='Path to model config file.')
)
parser.add_argument( parser.add_argument(
'model_path', 'model_path',
type=str, type=str,
@ -62,8 +60,10 @@ if __name__ == "__main__":
type=str, type=str,
help='Path to save final wav file.', help='Path to save final wav file.',
) )
parser.add_argument( parser.add_argument('--use_cuda',
'--use_cuda', type=bool, help='Run model on CUDA.', default=False) type=bool,
help='Run model on CUDA.',
default=False)
parser.add_argument( parser.add_argument(
'--vocoder_path', '--vocoder_path',
type=str, type=str,
@ -71,8 +71,7 @@ if __name__ == "__main__":
'Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).', 'Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).',
default="", default="",
) )
parser.add_argument( parser.add_argument('--vocoder_config_path',
'--vocoder_config_path',
type=str, type=str,
help='Path to vocoder model config file.', help='Path to vocoder model config file.',
default="") default="")
@ -81,18 +80,15 @@ if __name__ == "__main__":
type=bool, type=bool,
help="If True, vocoder model uses faster batch processing.", help="If True, vocoder model uses faster batch processing.",
default=True) default=True)
parser.add_argument( parser.add_argument('--speakers_json',
'--speakers_json',
type=str, type=str,
help="JSON file for multi-speaker model.", help="JSON file for multi-speaker model.",
default="" default="")
)
parser.add_argument( parser.add_argument(
'--speaker_id', '--speaker_id',
type=int, type=int,
help="target speaker_id if the model is multi-speaker.", help="target speaker_id if the model is multi-speaker.",
default=None default=None)
)
args = parser.parse_args() args = parser.parse_args()
if args.vocoder_path != "": if args.vocoder_path != "":
@ -128,8 +124,7 @@ if __name__ == "__main__":
VC = load_config(args.vocoder_config_path) VC = load_config(args.vocoder_config_path)
ap_vocoder = AudioProcessor(**VC.audio) ap_vocoder = AudioProcessor(**VC.audio)
bits = 10 bits = 10
vocoder_model = VocoderModel( vocoder_model = VocoderModel(rnn_dims=512,
rnn_dims=512,
fc_dims=512, fc_dims=512,
mode=VC.mode, mode=VC.mode,
mulaw=VC.mulaw, mulaw=VC.mulaw,
@ -142,8 +137,7 @@ if __name__ == "__main__":
hop_length=ap.hop_length, hop_length=ap.hop_length,
sample_rate=ap.sample_rate, sample_rate=ap.sample_rate,
use_aux_net=True, use_aux_net=True,
use_upsample_net=True use_upsample_net=True)
)
check = torch.load(args.vocoder_path) check = torch.load(args.vocoder_path)
vocoder_model.load_state_dict(check['model']) vocoder_model.load_state_dict(check['model'])
@ -157,8 +151,7 @@ if __name__ == "__main__":
# synthesize voice # synthesize voice
print(" > Text: {}".format(args.text)) print(" > Text: {}".format(args.text))
_, _, _, wav = tts( _, _, _, wav = tts(model,
model,
vocoder_model, vocoder_model,
C, C,
VC, VC,
@ -172,7 +165,8 @@ if __name__ == "__main__":
# save the results # save the results
file_name = args.text.replace(" ", "_") file_name = args.text.replace(" ", "_")
file_name = file_name.translate(str.maketrans('', '', string.punctuation.replace('_', '')))+'.wav' file_name = file_name.translate(
str.maketrans('', '', string.punctuation.replace('_', ''))) + '.wav'
out_path = os.path.join(args.out_path, file_name) out_path = os.path.join(args.out_path, file_name)
print(" > Saving output to {}".format(out_path)) print(" > Saving output to {}".format(out_path))
ap.save_wav(wav, out_path) ap.save_wav(wav, out_path)