diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index 9dce36fa..6d0497a9 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -161,24 +161,7 @@ class ForwardTTS(BaseTTS): # pylint: disable=dangerous-default-value def __init__(self, config: Coqpit): - super().__init__() - - # don't use isintance not to import recursively - if "Config" in config.__class__.__name__: - if "characters" in config: - # loading from FasrPitchConfig - _, self.config, num_chars = self.get_characters(config) - config.model_args.num_chars = num_chars - self.args = self.config.model_args - else: - # loading from ForwardTTSArgs - self.config = config - self.args = config.model_args - elif isinstance(config, ForwardTTSArgs): - self.args = config - self.config = config - else: - raise ValueError("config must be either a *Config or ForwardTTSArgs") + super().__init__(config) self.max_duration = self.args.max_duration self.use_aligner = self.args.use_aligner @@ -634,7 +617,8 @@ class ForwardTTS(BaseTTS): return outputs, loss_dict - def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use + def _create_logs(self, batch, outputs, ap): + """Create common logger outputs.""" model_outputs = outputs["model_outputs"] alignments = outputs["alignments"] mel_input = batch["mel_input"] @@ -674,11 +658,22 @@ class ForwardTTS(BaseTTS): train_audio = ap.inv_melspectrogram(pred_spec.T) return figures, {"audio": train_audio} + def train_log( + self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int + ) -> None: # pylint: disable=no-self-use + ap = assets["audio_processor"] + figures, audios = self._create_logs(batch, outputs, ap) + logger.train_figures(steps, figures) + logger.train_audios(steps, audios, ap.sample_rate) + def eval_step(self, batch: dict, criterion: nn.Module): return self.train_step(batch, criterion) - def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict): - return self.train_log(ap, batch, outputs) + def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None: + ap = assets["audio_processor"] + figures, audios = self._create_logs(batch, outputs, ap) + logger.eval_figures(steps, figures) + logger.eval_audios(steps, audios, ap.sample_rate) def load_checkpoint( self, config, checkpoint_path, eval=False diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index 614e42e0..fead67a0 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -1,8 +1,11 @@ import os from TTS.config import BaseAudioConfig, BaseDatasetConfig -from TTS.trainer import Trainer, TrainingArgs, init_training +from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs import FastPitchConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models.forward_tts import ForwardTTS +from TTS.utils.audio import AudioProcessor from TTS.utils.manage import ModelManager output_path = os.path.dirname(os.path.abspath(__file__)) @@ -64,7 +67,23 @@ if not config.model_args.use_aligner: f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true" ) -# train the model -args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) -trainer = Trainer(args, config, output_path, c_logger, tb_logger) +# init audio processor +ap = AudioProcessor(**config.audio) + +# load training samples +train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) + +# init the model +model = ForwardTTS(config) + +# init the trainer and 🚀 +trainer = Trainer( + TrainingArgs(), + config, + output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + training_assets={"audio_processor": ap}, +) trainer.fit() diff --git a/recipes/ljspeech/fast_speech/train_fast_speech.py b/recipes/ljspeech/fast_speech/train_fast_speech.py new file mode 100644 index 00000000..56557c26 --- /dev/null +++ b/recipes/ljspeech/fast_speech/train_fast_speech.py @@ -0,0 +1,88 @@ +import os + +from TTS.config import BaseAudioConfig, BaseDatasetConfig +from TTS.trainer import Trainer, TrainingArgs +from TTS.tts.configs import FastSpeechConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models.forward_tts import ForwardTTS +from TTS.utils.audio import AudioProcessor +from TTS.utils.manage import ModelManager + +output_path = os.path.dirname(os.path.abspath(__file__)) + +# init configs +dataset_config = BaseDatasetConfig( + name="ljspeech", + meta_file_train="metadata.csv", + # meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), + path=os.path.join(output_path, "../LJSpeech-1.1/"), +) + +audio_config = BaseAudioConfig( + sample_rate=22050, + do_trim_silence=True, + trim_db=60.0, + signal_norm=False, + mel_fmin=0.0, + mel_fmax=8000, + spec_gain=1.0, + log_func="np.log", + ref_level_db=20, + preemphasis=0.0, +) + +config = FastSpeechConfig( + run_name="fast_speech_ljspeech", + audio=audio_config, + batch_size=32, + eval_batch_size=16, + num_loader_workers=8, + num_eval_loader_workers=4, + compute_input_seq_cache=True, + compute_f0=False, + run_eval=True, + test_delay_epochs=-1, + epochs=1000, + text_cleaner="english_cleaners", + use_phonemes=True, + use_espeak_phonemes=False, + phoneme_language="en-us", + phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), + print_step=50, + print_eval=False, + mixed_precision=False, + sort_by_audio_len=True, + max_seq_len=500000, + output_path=output_path, + datasets=[dataset_config], +) + +# compute alignments +if not config.model_args.use_aligner: + manager = ModelManager() + model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA") + # TODO: make compute_attention python callable + os.system( + f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true" + ) + +# init audio processor +ap = AudioProcessor(**config.audio) + +# load training samples +train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) + +# init the model +model = ForwardTTS(config) + +# init the trainer and 🚀 +trainer = Trainer( + TrainingArgs(), + config, + output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + training_assets={"audio_processor": ap}, +) +trainer.fit() diff --git a/recipes/ljspeech/speedy_speech/train_speedy_speech.py b/recipes/ljspeech/speedy_speech/train_speedy_speech.py index 2882468f..27639e6b 100644 --- a/recipes/ljspeech/speedy_speech/train_speedy_speech.py +++ b/recipes/ljspeech/speedy_speech/train_speedy_speech.py @@ -1,18 +1,16 @@ import os from TTS.config import BaseAudioConfig, BaseDatasetConfig -from TTS.trainer import Trainer, TrainingArgs, init_training +from TTS.trainer import Trainer, TrainingArgs from TTS.tts.configs import SpeedySpeechConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models.forward_tts import ForwardTTS +from TTS.utils.audio import AudioProcessor from TTS.utils.manage import ModelManager output_path = os.path.dirname(os.path.abspath(__file__)) - -# init configs dataset_config = BaseDatasetConfig( - name="ljspeech", - meta_file_train="metadata.csv", - # meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), - path=os.path.join(output_path, "../LJSpeech-1.1/"), + name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/") ) audio_config = BaseAudioConfig( @@ -53,16 +51,32 @@ config = SpeedySpeechConfig( datasets=[dataset_config], ) -# compute alignments -if not config.model_args.use_aligner: - manager = ModelManager() - model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA") - # TODO: make compute_attention python callable - os.system( - f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true" - ) +# # compute alignments +# if not config.model_args.use_aligner: +# manager = ModelManager() +# model_path, config_path, _ = manager.download_model("tts_models/en/ljspeech/tacotron2-DCA") +# # TODO: make compute_attention python callable +# os.system( +# f"python TTS/bin/compute_attention_masks.py --model_path {model_path} --config_path {config_path} --dataset ljspeech --dataset_metafile metadata.csv --data_path ./recipes/ljspeech/LJSpeech-1.1/ --use_cuda true" +# ) -# train the model -args, config, output_path, _, c_logger, tb_logger = init_training(TrainingArgs(), config) -trainer = Trainer(args, config, output_path, c_logger, tb_logger) +# init audio processor +ap = AudioProcessor(**config.audio.to_dict()) + +# load training samples +train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True) + +# init model +model = ForwardTTS(config) + +# init the trainer and 🚀 +trainer = Trainer( + TrainingArgs(), + config, + output_path, + model=model, + train_samples=train_samples, + eval_samples=eval_samples, + training_assets={"audio_processor": ap}, +) trainer.fit()