mirror of https://github.com/coqui-ai/TTS.git
Add `base_model` field to `forward_tts` configs
This commit is contained in:
parent
22822cd41c
commit
66732025e1
|
@ -1182,7 +1182,6 @@ def process_args(args, config=None):
|
||||||
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
|
args.restore_path, best_model = get_last_checkpoint(args.continue_path)
|
||||||
if not args.best_path:
|
if not args.best_path:
|
||||||
args.best_path = best_model
|
args.best_path = best_model
|
||||||
|
|
||||||
# init config if not already defined
|
# init config if not already defined
|
||||||
if config is None:
|
if config is None:
|
||||||
if args.config_path:
|
if args.config_path:
|
||||||
|
|
|
@ -18,6 +18,10 @@ class FastPitchConfig(BaseTTSConfig):
|
||||||
model (str):
|
model (str):
|
||||||
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
|
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
|
||||||
|
|
||||||
|
base_model (str):
|
||||||
|
Name of the base model being configured as this model so that 🐸 TTS knows it needs to initialize
|
||||||
|
the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.
|
||||||
|
|
||||||
model_args (Coqpit):
|
model_args (Coqpit):
|
||||||
Model class arguments. Check `ForwardTTSArgs` for more details. Defaults to `ForwardTTSArgs()`.
|
Model class arguments. Check `ForwardTTSArgs` for more details. Defaults to `ForwardTTSArgs()`.
|
||||||
|
|
||||||
|
@ -94,9 +98,11 @@ class FastPitchConfig(BaseTTSConfig):
|
||||||
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
|
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model: str = "forward_tts"
|
model: str = "fast_pitch"
|
||||||
|
base_model: str = "forward_tts"
|
||||||
|
|
||||||
# model specific params
|
# model specific params
|
||||||
model_args: ForwardTTSArgs = field(default_factory=ForwardTTSArgs)
|
model_args: ForwardTTSArgs = ForwardTTSArgs()
|
||||||
|
|
||||||
# multi-speaker settings
|
# multi-speaker settings
|
||||||
use_speaker_embedding: bool = False
|
use_speaker_embedding: bool = False
|
||||||
|
|
|
@ -16,7 +16,11 @@ class SpeedySpeechConfig(BaseTTSConfig):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (str):
|
model (str):
|
||||||
Model name used for selecting the right model at initialization. Defaults to `fast_pitch`.
|
Model name used for selecting the right model at initialization. Defaults to `speedy_speech`.
|
||||||
|
|
||||||
|
base_model (str):
|
||||||
|
Name of the base model being configured as this model so that 🐸 TTS knows it needs to initialize
|
||||||
|
the base model rather than searching for the `model` implementation. Defaults to `forward_tts`.
|
||||||
|
|
||||||
model_args (Coqpit):
|
model_args (Coqpit):
|
||||||
Model class arguments. Check `ForwardTTSArgs` for more details. Defaults to a `ForwardTTSArgs` instance configured for SpeedySpeech.
|
Model class arguments. Check `ForwardTTSArgs` for more details. Defaults to a `ForwardTTSArgs` instance configured for SpeedySpeech.
|
||||||
|
@ -91,7 +95,8 @@ class SpeedySpeechConfig(BaseTTSConfig):
|
||||||
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
|
Maximum input sequence length to be used at training. Larger values result in more VRAM usage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model: str = "forward_tts"
|
model: str = "speedy_speech"
|
||||||
|
base_model: str = "forward_tts"
|
||||||
|
|
||||||
# set model args as SpeedySpeech
|
# set model args as SpeedySpeech
|
||||||
model_args: ForwardTTSArgs = ForwardTTSArgs(
|
model_args: ForwardTTSArgs = ForwardTTSArgs(
|
||||||
|
|
|
@ -4,6 +4,10 @@ from TTS.utils.generic_utils import find_module
|
||||||
|
|
||||||
def setup_model(config):
|
def setup_model(config):
|
||||||
print(" > Using model: {}".format(config.model))
|
print(" > Using model: {}".format(config.model))
|
||||||
|
# fetch the right model implementation.
|
||||||
|
if "base_model" in config and config["base_model"] is not None:
|
||||||
|
MyModel = find_module("TTS.tts.models", config.base_model.lower())
|
||||||
|
else:
|
||||||
MyModel = find_module("TTS.tts.models", config.model.lower())
|
MyModel = find_module("TTS.tts.models", config.model.lower())
|
||||||
# define set of characters used by the model
|
# define set of characters used by the model
|
||||||
if config.characters is not None:
|
if config.characters is not None:
|
||||||
|
|
Loading…
Reference in New Issue