Add `base_model` field to `forward_tts` configs

This commit is contained in:
Eren Gölge 2021-09-10 17:23:48 +00:00
parent 22822cd41c
commit 66732025e1
4 changed files with 20 additions and 6 deletions

View File

@ -1182,7 +1182,6 @@ def process_args(args, config=None):
args.restore_path, best_model = get_last_checkpoint(args.continue_path) args.restore_path, best_model = get_last_checkpoint(args.continue_path)
if not args.best_path: if not args.best_path:
args.best_path = best_model args.best_path = best_model
# init config if not already defined # init config if not already defined
if config is None: if config is None:
if args.config_path: if args.config_path:

View File

@ -18,6 +18,10 @@ class FastPitchConfig(BaseTTSConfig):
model (str): model (str):
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
base_model (str):
Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate
the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.
model_args (Coqpit): model_args (Coqpit):
Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`.
@ -94,9 +98,11 @@ class FastPitchConfig(BaseTTSConfig):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage. Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
""" """
model: str = "forward_tts" model: str = "fast_pitch"
base_model: str = "forward_tts"
# model specific params # model specific params
model_args: ForwardTTSArgs = field(default_factory=ForwardTTSArgs) model_args: ForwardTTSArgs = ForwardTTSArgs()
# multi-speaker settings # multi-speaker settings
use_speaker_embedding: bool = False use_speaker_embedding: bool = False

View File

@ -16,7 +16,11 @@ class SpeedySpeechConfig(BaseTTSConfig):
Args: Args:
model (str): model (str):
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`. Model name used for selecting the right model at initialization. Defaults to `speedy_speech`.
base_model (str):
Name of the base model being configured as this model so that 🐸 TTS knows it needs to initiate
the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.
model_args (Coqpit): model_args (Coqpit):
Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`. Model class arguments. Check `FastPitchArgs` for more details. Defaults to `FastPitchArgs()`.
@ -91,7 +95,8 @@ class SpeedySpeechConfig(BaseTTSConfig):
Maximum input sequence length to be used at training. Larger values result in more VRAM usage. Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
""" """
model: str = "forward_tts" model: str = "speedy_speech"
base_model: str = "forward_tts"
# set model args as SpeedySpeech # set model args as SpeedySpeech
model_args: ForwardTTSArgs = ForwardTTSArgs( model_args: ForwardTTSArgs = ForwardTTSArgs(

View File

@ -4,7 +4,11 @@ from TTS.utils.generic_utils import find_module
def setup_model(config): def setup_model(config):
print(" > Using model: {}".format(config.model)) print(" > Using model: {}".format(config.model))
MyModel = find_module("TTS.tts.models", config.model.lower()) # fetch the right model implementation.
if "base_model" in config and config["base_model"] is not None:
MyModel = find_module("TTS.tts.models", config.base_model.lower())
else:
MyModel = find_module("TTS.tts.models", config.model.lower())
# define set of characters used by the model # define set of characters used by the model
if config.characters is not None: if config.characters is not None:
# set characters from config # set characters from config