Update `tts.models.setup_model`

This commit is contained in:
Eren Gölge 2021-06-18 13:36:47 +02:00
parent cae702980f
commit d10f9c5676
1 changed files with 38 additions and 106 deletions

View File

@ -1,110 +1,42 @@
from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
from TTS.utils.generic_utils import find_module
def setup_model(num_chars, num_speakers, c, d_vector_dim=None):
print(" > Using model: {}".format(c.model))
MyModel = find_module("TTS.tts.models", c.model.lower())
if c.model.lower() in "tacotron":
model = MyModel(
num_chars=num_chars + getattr(c, "add_blank", False),
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=int(c.audio["fft_size"] / 2 + 1),
decoder_output_dim=c.audio["num_mels"],
use_gst=c.use_gst,
gst=c.gst,
memory_size=c.memory_size,
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
prenet_dropout_at_inference=c.prenet_dropout_at_inference,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
double_decoder_consistency=c.double_decoder_consistency,
ddc_r=c.ddc_r,
d_vector_dim=d_vector_dim,
max_decoder_steps=c.max_decoder_steps,
)
elif c.model.lower() == "tacotron2":
model = MyModel(
num_chars=num_chars + getattr(c, "add_blank", False),
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=c.audio["num_mels"],
decoder_output_dim=c.audio["num_mels"],
use_gst=c.use_gst,
gst=c.gst,
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
prenet_dropout_at_inference=c.prenet_dropout_at_inference,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
double_decoder_consistency=c.double_decoder_consistency,
ddc_r=c.ddc_r,
d_vector_dim=d_vector_dim,
max_decoder_steps=c.max_decoder_steps,
)
elif c.model.lower() == "glow_tts":
model = MyModel(
num_chars=num_chars + getattr(c, "add_blank", False),
hidden_channels_enc=c["hidden_channels_encoder"],
hidden_channels_dec=c["hidden_channels_decoder"],
hidden_channels_dp=c["hidden_channels_duration_predictor"],
out_channels=c.audio["num_mels"],
encoder_type=c.encoder_type,
encoder_params=c.encoder_params,
use_encoder_prenet=c["use_encoder_prenet"],
inference_noise_scale=c.inference_noise_scale,
num_flow_blocks_dec=12,
kernel_size_dec=5,
dilation_rate=1,
num_block_layers=4,
dropout_p_dec=0.05,
num_speakers=num_speakers,
c_in_channels=0,
num_splits=4,
num_squeeze=2,
sigmoid_scale=False,
mean_only=True,
d_vector_dim=d_vector_dim,
)
elif c.model.lower() == "speedy_speech":
model = MyModel(
num_chars=num_chars + getattr(c, "add_blank", False),
out_channels=c.audio["num_mels"],
hidden_channels=c["hidden_channels"],
positional_encoding=c["positional_encoding"],
encoder_type=c["encoder_type"],
encoder_params=c["encoder_params"],
decoder_type=c["decoder_type"],
decoder_params=c["decoder_params"],
c_in_channels=0,
)
elif c.model.lower() == "align_tts":
model = MyModel(
num_chars=num_chars + getattr(c, "add_blank", False),
out_channels=c.audio["num_mels"],
hidden_channels=c["hidden_channels"],
hidden_channels_dp=c["hidden_channels_dp"],
encoder_type=c["encoder_type"],
encoder_params=c["encoder_params"],
decoder_type=c["decoder_type"],
decoder_params=c["decoder_params"],
c_in_channels=0,
)
def setup_model(config):
print(" > Using model: {}".format(config.model))
MyModel = find_module("TTS.tts.models", config.model.lower())
# define set of characters used by the model
if config.characters is not None:
# set characters from config
symbols, phonemes = make_symbols(**config.characters.to_dict()) # pylint: disable=redefined-outer-name
else:
from TTS.tts.utils.text.symbols import phonemes, symbols # pylint: disable=import-outside-toplevel
# use default characters and assign them to config
config.characters = parse_symbols()
num_chars = len(phonemes) if config.use_phonemes else len(symbols)
# consider special `blank` character if `add_blank` is set True
num_chars = num_chars + getattr(config, "add_blank", False)
config.num_chars = num_chars
# compatibility fix
if "model_params" in config:
config.model_params.num_chars = num_chars
if "model_args" in config:
config.model_args.num_chars = num_chars
model = MyModel(config)
return model
# TODO; class registery
# def import_models(models_dir, namespace):
# for file in os.listdir(models_dir):
# path = os.path.join(models_dir, file)
# if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
# model_name = file[: file.find(".py")] if file.endswith(".py") else file
# importlib.import_module(namespace + "." + model_name)
#
#
## automatically import any Python files in the models/ directory
# models_dir = os.path.dirname(__file__)
# import_models(models_dir, "TTS.tts.models")