diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 5828411c..33724919 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -10,6 +10,8 @@ import torch from torch.utils.data import DataLoader from trainer.torch import NoamLR +from trainer.torch import NoamLR + from TTS.speaker_encoder.dataset import SpeakerEncoderDataset from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_speaker_encoder_model diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 31813712..1bca7430 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field import os from dataclasses import dataclass, field diff --git a/TTS/bin/train_vocoder.py b/TTS/bin/train_vocoder.py index 32ecd7bd..1745d6ab 100644 --- a/TTS/bin/train_vocoder.py +++ b/TTS/bin/train_vocoder.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass, field import os from dataclasses import dataclass, field diff --git a/TTS/speaker_encoder/utils/training.py b/TTS/speaker_encoder/utils/training.py index 0bc72af8..c07915c9 100644 --- a/TTS/speaker_encoder/utils/training.py +++ b/TTS/speaker_encoder/utils/training.py @@ -1,3 +1,5 @@ +from asyncio.log import logger +from dataclasses import dataclass, field import os from dataclasses import dataclass, field @@ -7,10 +9,12 @@ from trainer.logging import logger_factory from trainer.logging.console_logger import ConsoleLogger from TTS.config import load_config, register_config +from trainer import TrainerArgs, get_last_checkpoint from TTS.tts.utils.text.characters import parse_symbols from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch from TTS.utils.io import copy_model_files -from trainer import get_last_checkpoint +from trainer.logging import logger_factory +from trainer.logging.console_logger import ConsoleLogger @dataclass