diff --git a/TTS/tts/models/__init__.py b/TTS/tts/models/__init__.py index 2a951267..c6390beb 100644 --- a/TTS/tts/models/__init__.py +++ b/TTS/tts/models/__init__.py @@ -1,110 +1,42 @@ +from TTS.tts.utils.text.symbols import make_symbols, parse_symbols from TTS.utils.generic_utils import find_module -def setup_model(num_chars, num_speakers, c, d_vector_dim=None): - print(" > Using model: {}".format(c.model)) - MyModel = find_module("TTS.tts.models", c.model.lower()) - if c.model.lower() in "tacotron": - model = MyModel( - num_chars=num_chars + getattr(c, "add_blank", False), - num_speakers=num_speakers, - r=c.r, - postnet_output_dim=int(c.audio["fft_size"] / 2 + 1), - decoder_output_dim=c.audio["num_mels"], - use_gst=c.use_gst, - gst=c.gst, - memory_size=c.memory_size, - attn_type=c.attention_type, - attn_win=c.windowing, - attn_norm=c.attention_norm, - prenet_type=c.prenet_type, - prenet_dropout=c.prenet_dropout, - prenet_dropout_at_inference=c.prenet_dropout_at_inference, - forward_attn=c.use_forward_attn, - trans_agent=c.transition_agent, - forward_attn_mask=c.forward_attn_mask, - location_attn=c.location_attn, - attn_K=c.attention_heads, - separate_stopnet=c.separate_stopnet, - bidirectional_decoder=c.bidirectional_decoder, - double_decoder_consistency=c.double_decoder_consistency, - ddc_r=c.ddc_r, - d_vector_dim=d_vector_dim, - max_decoder_steps=c.max_decoder_steps, - ) - elif c.model.lower() == "tacotron2": - model = MyModel( - num_chars=num_chars + getattr(c, "add_blank", False), - num_speakers=num_speakers, - r=c.r, - postnet_output_dim=c.audio["num_mels"], - decoder_output_dim=c.audio["num_mels"], - use_gst=c.use_gst, - gst=c.gst, - attn_type=c.attention_type, - attn_win=c.windowing, - attn_norm=c.attention_norm, - prenet_type=c.prenet_type, - prenet_dropout=c.prenet_dropout, - prenet_dropout_at_inference=c.prenet_dropout_at_inference, - forward_attn=c.use_forward_attn, - trans_agent=c.transition_agent, - forward_attn_mask=c.forward_attn_mask, - location_attn=c.location_attn, - attn_K=c.attention_heads, - separate_stopnet=c.separate_stopnet, - bidirectional_decoder=c.bidirectional_decoder, - double_decoder_consistency=c.double_decoder_consistency, - ddc_r=c.ddc_r, - d_vector_dim=d_vector_dim, - max_decoder_steps=c.max_decoder_steps, - ) - elif c.model.lower() == "glow_tts": - model = MyModel( - num_chars=num_chars + getattr(c, "add_blank", False), - hidden_channels_enc=c["hidden_channels_encoder"], - hidden_channels_dec=c["hidden_channels_decoder"], - hidden_channels_dp=c["hidden_channels_duration_predictor"], - out_channels=c.audio["num_mels"], - encoder_type=c.encoder_type, - encoder_params=c.encoder_params, - use_encoder_prenet=c["use_encoder_prenet"], - inference_noise_scale=c.inference_noise_scale, - num_flow_blocks_dec=12, - kernel_size_dec=5, - dilation_rate=1, - num_block_layers=4, - dropout_p_dec=0.05, - num_speakers=num_speakers, - c_in_channels=0, - num_splits=4, - num_squeeze=2, - sigmoid_scale=False, - mean_only=True, - d_vector_dim=d_vector_dim, - ) - elif c.model.lower() == "speedy_speech": - model = MyModel( - num_chars=num_chars + getattr(c, "add_blank", False), - out_channels=c.audio["num_mels"], - hidden_channels=c["hidden_channels"], - positional_encoding=c["positional_encoding"], - encoder_type=c["encoder_type"], - encoder_params=c["encoder_params"], - decoder_type=c["decoder_type"], - decoder_params=c["decoder_params"], - c_in_channels=0, - ) - elif c.model.lower() == "align_tts": - model = MyModel( - num_chars=num_chars + getattr(c, "add_blank", False), - out_channels=c.audio["num_mels"], - hidden_channels=c["hidden_channels"], - hidden_channels_dp=c["hidden_channels_dp"], - encoder_type=c["encoder_type"], - encoder_params=c["encoder_params"], - decoder_type=c["decoder_type"], - decoder_params=c["decoder_params"], - c_in_channels=0, - ) +def setup_model(config): + print(" > Using model: {}".format(config.model)) + + MyModel = find_module("TTS.tts.models", config.model.lower()) + # define set of characters used by the model + if config.characters is not None: + # set characters from config + symbols, phonemes = make_symbols(**config.characters.to_dict()) # pylint: disable=redefined-outer-name + else: + from TTS.tts.utils.text.symbols import phonemes, symbols # pylint: disable=import-outside-toplevel + + # use default characters and assign them to config + config.characters = parse_symbols() + num_chars = len(phonemes) if config.use_phonemes else len(symbols) + # consider special `blank` character if `add_blank` is set True + num_chars = num_chars + getattr(config, "add_blank", False) + config.num_chars = num_chars + # compatibility fix + if "model_params" in config: + config.model_params.num_chars = num_chars + if "model_args" in config: + config.model_args.num_chars = num_chars + model = MyModel(config) return model + + +# TODO; class registery +# def import_models(models_dir, namespace): +# for file in os.listdir(models_dir): +# path = os.path.join(models_dir, file) +# if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)): +# model_name = file[: file.find(".py")] if file.endswith(".py") else file +# importlib.import_module(namespace + "." + model_name) +# +# +## automatically import any Python files in the models/ directory +# models_dir = os.path.dirname(__file__) +# import_models(models_dir, "TTS.tts.models")