mirror of https://github.com/coqui-ai/TTS.git
Merge pull request #553 from Edresson/dev
bug fix in the inference with GlowTTS
This commit is contained in:
commit
26c18b61c9
|
@ -10,7 +10,7 @@ import time
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from TTS.tts.utils.generic_utils import setup_model
|
from TTS.tts.utils.generic_utils import setup_model, is_tacotron
|
||||||
from TTS.tts.utils.synthesis import synthesis
|
from TTS.tts.utils.synthesis import synthesis
|
||||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||||
from TTS.utils.audio import AudioProcessor
|
from TTS.utils.audio import AudioProcessor
|
||||||
|
@ -125,6 +125,7 @@ if __name__ == "__main__":
|
||||||
model.eval()
|
model.eval()
|
||||||
if args.use_cuda:
|
if args.use_cuda:
|
||||||
model.cuda()
|
model.cuda()
|
||||||
|
if is_tacotron(C):
|
||||||
model.decoder.set_r(cp['r'])
|
model.decoder.set_r(cp['r'])
|
||||||
|
|
||||||
# load vocoder model
|
# load vocoder model
|
||||||
|
@ -153,7 +154,10 @@ if __name__ == "__main__":
|
||||||
args.speaker_fileid = None
|
args.speaker_fileid = None
|
||||||
|
|
||||||
if args.gst_style is None:
|
if args.gst_style is None:
|
||||||
|
if is_tacotron(C):
|
||||||
gst_style = C.gst['gst_style_input']
|
gst_style = C.gst['gst_style_input']
|
||||||
|
else:
|
||||||
|
gst_style = None
|
||||||
else:
|
else:
|
||||||
# check if gst_style string is a dict, if is dict convert else use string
|
# check if gst_style string is a dict, if is dict convert else use string
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -28,7 +28,6 @@ def split_dataset(items):
|
||||||
return items_eval, items
|
return items_eval, items
|
||||||
return items[:eval_split_size], items[eval_split_size:]
|
return items[:eval_split_size], items[eval_split_size:]
|
||||||
|
|
||||||
|
|
||||||
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
||||||
def sequence_mask(sequence_length, max_len=None):
|
def sequence_mask(sequence_length, max_len=None):
|
||||||
if max_len is None:
|
if max_len is None:
|
||||||
|
@ -50,7 +49,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
||||||
MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower())
|
MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower())
|
||||||
MyModel = getattr(MyModel, to_camel(c.model))
|
MyModel = getattr(MyModel, to_camel(c.model))
|
||||||
if c.model.lower() in "tacotron":
|
if c.model.lower() in "tacotron":
|
||||||
model = MyModel(num_chars=num_chars,
|
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
|
||||||
num_speakers=num_speakers,
|
num_speakers=num_speakers,
|
||||||
r=c.r,
|
r=c.r,
|
||||||
postnet_output_dim=int(c.audio['fft_size'] / 2 + 1),
|
postnet_output_dim=int(c.audio['fft_size'] / 2 + 1),
|
||||||
|
@ -77,7 +76,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
||||||
ddc_r=c.ddc_r,
|
ddc_r=c.ddc_r,
|
||||||
speaker_embedding_dim=speaker_embedding_dim)
|
speaker_embedding_dim=speaker_embedding_dim)
|
||||||
elif c.model.lower() == "tacotron2":
|
elif c.model.lower() == "tacotron2":
|
||||||
model = MyModel(num_chars=num_chars,
|
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
|
||||||
num_speakers=num_speakers,
|
num_speakers=num_speakers,
|
||||||
r=c.r,
|
r=c.r,
|
||||||
postnet_output_dim=c.audio['num_mels'],
|
postnet_output_dim=c.audio['num_mels'],
|
||||||
|
@ -103,7 +102,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
||||||
ddc_r=c.ddc_r,
|
ddc_r=c.ddc_r,
|
||||||
speaker_embedding_dim=speaker_embedding_dim)
|
speaker_embedding_dim=speaker_embedding_dim)
|
||||||
elif c.model.lower() == "glow_tts":
|
elif c.model.lower() == "glow_tts":
|
||||||
model = MyModel(num_chars=num_chars,
|
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
|
||||||
hidden_channels=192,
|
hidden_channels=192,
|
||||||
filter_channels=768,
|
filter_channels=768,
|
||||||
filter_channels_dp=256,
|
filter_channels_dp=256,
|
||||||
|
@ -131,7 +130,7 @@ def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def is_tacotron(c):
|
def is_tacotron(c):
|
||||||
return False if c['model'] == 'glow_tts' else True
|
return False if 'glow_tts' in c['model'] else True
|
||||||
|
|
||||||
def check_config_tts(c):
|
def check_config_tts(c):
|
||||||
check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts'], restricted=True, val_type=str)
|
check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts'], restricted=True, val_type=str)
|
||||||
|
|
Loading…
Reference in New Issue