From 63625e79af1e13928474fdf964a3322273542939 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Wed, 27 Nov 2024 16:12:38 +0100 Subject: [PATCH] refactor: import get_last_checkpoint from trainer.io --- TTS/bin/compute_attention_masks.py | 2 +- TTS/encoder/utils/training.py | 4 ++-- tests/tts_tests/test_neuralhmm_tts_train.py | 2 +- tests/tts_tests/test_overflow_train.py | 2 +- tests/tts_tests/test_speedy_speech_train.py | 2 +- tests/tts_tests/test_tacotron2_d-vectors_train.py | 2 +- tests/tts_tests/test_tacotron2_speaker_emb_train.py | 2 +- tests/tts_tests/test_tacotron2_train.py | 2 +- tests/tts_tests/test_tacotron_train.py | 2 +- tests/tts_tests/test_vits_multilingual_speaker_emb_train.py | 2 +- tests/tts_tests/test_vits_multilingual_train-d_vectors.py | 2 +- tests/tts_tests/test_vits_speaker_emb_train.py | 2 +- tests/tts_tests/test_vits_train.py | 2 +- tests/tts_tests2/test_align_tts_train.py | 2 +- tests/tts_tests2/test_delightful_tts_d-vectors_train.py | 2 +- tests/tts_tests2/test_delightful_tts_emb_spk.py | 2 +- tests/tts_tests2/test_delightful_tts_train.py | 2 +- tests/tts_tests2/test_fast_pitch_speaker_emb_train.py | 2 +- tests/tts_tests2/test_fast_pitch_train.py | 2 +- tests/tts_tests2/test_fastspeech_2_speaker_emb_train.py | 2 +- tests/tts_tests2/test_fastspeech_2_train.py | 2 +- tests/tts_tests2/test_glow_tts_d-vectors_train.py | 2 +- tests/tts_tests2/test_glow_tts_speaker_emb_train.py | 2 +- tests/tts_tests2/test_glow_tts_train.py | 2 +- 24 files changed, 25 insertions(+), 25 deletions(-) diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index 12719918..535182d2 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -80,7 +80,7 @@ Example run: num_chars = len(phonemes) if C.use_phonemes else len(symbols) # TODO: handle multi-speaker model = setup_model(C) - model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True) + model, _ = load_checkpoint(model, args.model_path, use_cuda=args.use_cuda, eval=True) # data loader preprocessor = importlib.import_module("TTS.tts.datasets.formatters") diff --git a/TTS/encoder/utils/training.py b/TTS/encoder/utils/training.py index cc3a78b0..48629c7a 100644 --- a/TTS/encoder/utils/training.py +++ b/TTS/encoder/utils/training.py @@ -2,9 +2,9 @@ import os from dataclasses import dataclass, field from coqpit import Coqpit -from trainer import TrainerArgs, get_last_checkpoint +from trainer import TrainerArgs from trainer.generic_utils import get_experiment_folder_path, get_git_branch -from trainer.io import copy_model_files +from trainer.io import copy_model_files, get_last_checkpoint from trainer.logging import logger_factory from trainer.logging.console_logger import ConsoleLogger diff --git a/tests/tts_tests/test_neuralhmm_tts_train.py b/tests/tts_tests/test_neuralhmm_tts_train.py index 25d9aa81..4789d53d 100644 --- a/tests/tts_tests/test_neuralhmm_tts_train.py +++ b/tests/tts_tests/test_neuralhmm_tts_train.py @@ -4,7 +4,7 @@ import os import shutil import torch -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.neuralhmm_tts_config import NeuralhmmTTSConfig diff --git a/tests/tts_tests/test_overflow_train.py b/tests/tts_tests/test_overflow_train.py index 86fa60af..d86bde68 100644 --- a/tests/tts_tests/test_overflow_train.py +++ b/tests/tts_tests/test_overflow_train.py @@ -4,7 +4,7 @@ import os import shutil import torch -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.overflow_config import OverflowConfig diff --git a/tests/tts_tests/test_speedy_speech_train.py b/tests/tts_tests/test_speedy_speech_train.py index 530781ef..2aac7f10 100644 --- a/tests/tts_tests/test_speedy_speech_train.py +++ b/tests/tts_tests/test_speedy_speech_train.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig diff --git a/tests/tts_tests/test_tacotron2_d-vectors_train.py b/tests/tts_tests/test_tacotron2_d-vectors_train.py index 99ba4349..d2d1d5c3 100644 --- a/tests/tts_tests/test_tacotron2_d-vectors_train.py +++ b/tests/tts_tests/test_tacotron2_d-vectors_train.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config diff --git a/tests/tts_tests/test_tacotron2_speaker_emb_train.py b/tests/tts_tests/test_tacotron2_speaker_emb_train.py index 5f1bc3fd..83a07d1a 100644 --- a/tests/tts_tests/test_tacotron2_speaker_emb_train.py +++ b/tests/tts_tests/test_tacotron2_speaker_emb_train.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config diff --git a/tests/tts_tests/test_tacotron2_train.py b/tests/tts_tests/test_tacotron2_train.py index 40107070..df0e934d 100644 --- a/tests/tts_tests/test_tacotron2_train.py +++ b/tests/tts_tests/test_tacotron2_train.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config diff --git a/tests/tts_tests/test_tacotron_train.py b/tests/tts_tests/test_tacotron_train.py index f7751931..17f1fd46 100644 --- a/tests/tts_tests/test_tacotron_train.py +++ b/tests/tts_tests/test_tacotron_train.py @@ -2,7 +2,7 @@ import glob import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron_config import TacotronConfig diff --git a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py index 71597ef3..09df7d29 100644 --- a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py +++ b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseDatasetConfig diff --git a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py index fd58db53..7ae09c0e 100644 --- a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py +++ b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseDatasetConfig diff --git a/tests/tts_tests/test_vits_speaker_emb_train.py b/tests/tts_tests/test_vits_speaker_emb_train.py index b7fe197c..69fae21f 100644 --- a/tests/tts_tests/test_vits_speaker_emb_train.py +++ b/tests/tts_tests/test_vits_speaker_emb_train.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.vits_config import VitsConfig diff --git a/tests/tts_tests/test_vits_train.py b/tests/tts_tests/test_vits_train.py index ea5dc024..78f42d15 100644 --- a/tests/tts_tests/test_vits_train.py +++ b/tests/tts_tests/test_vits_train.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.vits_config import VitsConfig diff --git a/tests/tts_tests2/test_align_tts_train.py b/tests/tts_tests2/test_align_tts_train.py index 9b0b730d..91c3c35b 100644 --- a/tests/tts_tests2/test_align_tts_train.py +++ b/tests/tts_tests2/test_align_tts_train.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.align_tts_config import AlignTTSConfig diff --git a/tests/tts_tests2/test_delightful_tts_d-vectors_train.py b/tests/tts_tests2/test_delightful_tts_d-vectors_train.py index 8fc4ea7e..1e5cd49f 100644 --- a/tests/tts_tests2/test_delightful_tts_d-vectors_train.py +++ b/tests/tts_tests2/test_delightful_tts_d-vectors_train.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.delightful_tts_config import DelightfulTtsAudioConfig, DelightfulTTSConfig diff --git a/tests/tts_tests2/test_delightful_tts_emb_spk.py b/tests/tts_tests2/test_delightful_tts_emb_spk.py index 6fb70c5f..9bbf7a55 100644 --- a/tests/tts_tests2/test_delightful_tts_emb_spk.py +++ b/tests/tts_tests2/test_delightful_tts_emb_spk.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.delightful_tts_config import DelightfulTtsAudioConfig, DelightfulTTSConfig diff --git a/tests/tts_tests2/test_delightful_tts_train.py b/tests/tts_tests2/test_delightful_tts_train.py index a917d776..3e6fbd2e 100644 --- a/tests/tts_tests2/test_delightful_tts_train.py +++ b/tests/tts_tests2/test_delightful_tts_train.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig diff --git a/tests/tts_tests2/test_fast_pitch_speaker_emb_train.py b/tests/tts_tests2/test_fast_pitch_speaker_emb_train.py index 7f79bfca..e6bc9f9f 100644 --- a/tests/tts_tests2/test_fast_pitch_speaker_emb_train.py +++ b/tests/tts_tests2/test_fast_pitch_speaker_emb_train.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig diff --git a/tests/tts_tests2/test_fast_pitch_train.py b/tests/tts_tests2/test_fast_pitch_train.py index a525715b..fe87c8b6 100644 --- a/tests/tts_tests2/test_fast_pitch_train.py +++ b/tests/tts_tests2/test_fast_pitch_train.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig diff --git a/tests/tts_tests2/test_fastspeech_2_speaker_emb_train.py b/tests/tts_tests2/test_fastspeech_2_speaker_emb_train.py index 35bda597..735d2fc4 100644 --- a/tests/tts_tests2/test_fastspeech_2_speaker_emb_train.py +++ b/tests/tts_tests2/test_fastspeech_2_speaker_emb_train.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig diff --git a/tests/tts_tests2/test_fastspeech_2_train.py b/tests/tts_tests2/test_fastspeech_2_train.py index dd4b07d2..07fc5a1a 100644 --- a/tests/tts_tests2/test_fastspeech_2_train.py +++ b/tests/tts_tests2/test_fastspeech_2_train.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig diff --git a/tests/tts_tests2/test_glow_tts_d-vectors_train.py b/tests/tts_tests2/test_glow_tts_d-vectors_train.py index f1cfd436..8236607c 100644 --- a/tests/tts_tests2/test_glow_tts_d-vectors_train.py +++ b/tests/tts_tests2/test_glow_tts_d-vectors_train.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.glow_tts_config import GlowTTSConfig diff --git a/tests/tts_tests2/test_glow_tts_speaker_emb_train.py b/tests/tts_tests2/test_glow_tts_speaker_emb_train.py index b1eb6237..4a8bd065 100644 --- a/tests/tts_tests2/test_glow_tts_speaker_emb_train.py +++ b/tests/tts_tests2/test_glow_tts_speaker_emb_train.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.glow_tts_config import GlowTTSConfig diff --git a/tests/tts_tests2/test_glow_tts_train.py b/tests/tts_tests2/test_glow_tts_train.py index 0a8e226b..1d7f9135 100644 --- a/tests/tts_tests2/test_glow_tts_train.py +++ b/tests/tts_tests2/test_glow_tts_train.py @@ -3,7 +3,7 @@ import json import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.glow_tts_config import GlowTTSConfig