mirror of https://github.com/coqui-ai/TTS.git
typing annotation for the trainer
This commit is contained in:
parent
57cdddef16
commit
6bf6543df8
275
TTS/trainer.py
275
TTS/trainer.py
|
@ -65,12 +65,12 @@ class TrainerTTS:
|
||||||
use_cuda, num_gpus = setup_torch_training_env(True, False)
|
use_cuda, num_gpus = setup_torch_training_env(True, False)
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
args,
|
args: Union[Coqpit, Namespace],
|
||||||
config,
|
config: Coqpit,
|
||||||
c_logger,
|
c_logger: ConsoleLogger,
|
||||||
tb_logger,
|
tb_logger: TensorboardLogger,
|
||||||
model=None,
|
model: nn.Module = None,
|
||||||
output_path=None):
|
output_path: str = None) -> None:
|
||||||
self.args = args
|
self.args = args
|
||||||
self.config = config
|
self.config = config
|
||||||
self.c_logger = c_logger
|
self.c_logger = c_logger
|
||||||
|
@ -88,43 +88,52 @@ class TrainerTTS:
|
||||||
self.keep_avg_train = None
|
self.keep_avg_train = None
|
||||||
self.keep_avg_eval = None
|
self.keep_avg_eval = None
|
||||||
|
|
||||||
|
log_file = os.path.join(self.output_path,
|
||||||
|
f"trainer_{args.rank}_log.txt")
|
||||||
|
self._setup_logger_config(log_file)
|
||||||
|
|
||||||
# model, audio processor, datasets, loss
|
# model, audio processor, datasets, loss
|
||||||
# init audio processor
|
# init audio processor
|
||||||
self.ap = AudioProcessor(**config.audio.to_dict())
|
self.ap = AudioProcessor(**self.config.audio.to_dict())
|
||||||
|
|
||||||
# init character processor
|
# init character processor
|
||||||
self.model_characters = self.init_character_processor()
|
self.model_characters = self.get_character_processor(self.config)
|
||||||
|
|
||||||
# load dataset samples
|
# load dataset samples
|
||||||
self.data_train, self.data_eval = load_meta_data(config.datasets)
|
self.data_train, self.data_eval = load_meta_data(self.config.datasets)
|
||||||
|
|
||||||
# default speaker manager
|
# default speaker manager
|
||||||
self.speaker_manager = self.init_speaker_manager()
|
self.speaker_manager = self.get_speaker_manager(
|
||||||
|
self.config, args.restore_path, self.config.output_path, self.data_train)
|
||||||
|
|
||||||
# init TTS model
|
# init TTS model
|
||||||
if model is not None:
|
if model is not None:
|
||||||
self.model = model
|
self.model = model
|
||||||
else:
|
else:
|
||||||
self.model = self.init_model()
|
self.model = self.get_model(
|
||||||
|
len(self.model_characters), self.speaker_manager.num_speakers,
|
||||||
|
self.config, self.speaker_manager.x_vector_dim
|
||||||
|
if self.speaker_manager.x_vectors else None)
|
||||||
|
|
||||||
# setup criterion
|
# setup criterion
|
||||||
self.criterion = self.init_criterion()
|
self.criterion = self.get_criterion(self.config)
|
||||||
|
|
||||||
|
if self.use_cuda:
|
||||||
|
self.model.cuda()
|
||||||
|
self.criterion.cuda()
|
||||||
|
|
||||||
# DISTRUBUTED
|
# DISTRUBUTED
|
||||||
if self.num_gpus > 1:
|
if self.num_gpus > 1:
|
||||||
init_distributed(args.rank, self.num_gpus, args.group_id,
|
init_distributed(args.rank, self.num_gpus, args.group_id,
|
||||||
config.distributed["backend"],
|
self.config.distributed["backend"],
|
||||||
config.distributed["url"])
|
self.config.distributed["url"])
|
||||||
|
|
||||||
# scalers for mixed precision training
|
# scalers for mixed precision training
|
||||||
self.scaler = torch.cuda.amp.GradScaler(
|
self.scaler = torch.cuda.amp.GradScaler(
|
||||||
) if config.mixed_precision else None
|
) if self.config.mixed_precision and self.use_cuda else None
|
||||||
|
|
||||||
# setup optimizer
|
# setup optimizer
|
||||||
self.optimizer = self.init_optimizer(self.model)
|
self.optimizer = self.get_optimizer(self.model, self.config)
|
||||||
|
|
||||||
# setup scheduler
|
|
||||||
self.scheduler = self.init_scheduler(self.config, self.optimizer)
|
|
||||||
|
|
||||||
if self.args.restore_path:
|
if self.args.restore_path:
|
||||||
self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model(
|
self.model, self.optimizer, self.scaler, self.restore_step = self.restore_model(
|
||||||
|
@ -144,64 +153,66 @@ class TrainerTTS:
|
||||||
logging.info("\n > Model has {} parameters".format(num_params),
|
logging.info("\n > Model has {} parameters".format(num_params),
|
||||||
flush=True)
|
flush=True)
|
||||||
|
|
||||||
def init_model(self):
|
@staticmethod
|
||||||
model = setup_model(
|
def get_model(num_chars: int, num_speakers: int, config: Coqpit,
|
||||||
len(self.model_characters),
|
x_vector_dim: int) -> nn.Module:
|
||||||
self.speaker_manager.num_speakers,
|
model = setup_model(num_chars, num_speakers, config, x_vector_dim)
|
||||||
self.config,
|
|
||||||
self.speaker_manager.x_vector_dim
|
|
||||||
if self.speaker_manager.x_vectors else None,
|
|
||||||
)
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def init_optimizer(self, model):
|
@staticmethod
|
||||||
optimizer_name = self.config.optimizer
|
def get_optimizer(model: nn.Module, config: Coqpit) -> torch.optim.Optimizer:
|
||||||
optimizer_params = self.config.optimizer_params
|
optimizer_name = config.optimizer
|
||||||
|
optimizer_params = config.optimizer_params
|
||||||
if optimizer_name.lower() == "radam":
|
if optimizer_name.lower() == "radam":
|
||||||
module = importlib.import_module("TTS.utils.radam")
|
module = importlib.import_module("TTS.utils.radam")
|
||||||
optimizer = getattr(module, "RAdam")
|
optimizer = getattr(module, "RAdam")
|
||||||
else:
|
else:
|
||||||
optimizer = getattr(torch.optim, optimizer_name)
|
optimizer = getattr(torch.optim, optimizer_name)
|
||||||
return optimizer(model.parameters(),
|
return optimizer(model.parameters(), lr=config.lr, **optimizer_params)
|
||||||
lr=self.config.lr,
|
|
||||||
**optimizer_params)
|
|
||||||
|
|
||||||
def init_character_processor(self):
|
@staticmethod
|
||||||
|
def get_character_processor(config: Coqpit) -> str:
|
||||||
# setup custom characters if set in config file.
|
# setup custom characters if set in config file.
|
||||||
# TODO: implement CharacterProcessor
|
# TODO: implement CharacterProcessor
|
||||||
if self.config.characters is not None:
|
if config.characters is not None:
|
||||||
symbols, phonemes = make_symbols(
|
symbols, phonemes = make_symbols(**config.characters.to_dict())
|
||||||
**self.config.characters.to_dict())
|
|
||||||
else:
|
else:
|
||||||
from TTS.tts.utils.text.symbols import symbols, phonemes
|
from TTS.tts.utils.text.symbols import phonemes, symbols
|
||||||
model_characters = phonemes if self.config.use_phonemes else symbols
|
model_characters = phonemes if config.use_phonemes else symbols
|
||||||
return model_characters
|
return model_characters
|
||||||
|
|
||||||
def init_speaker_manager(self, restore_path: str = "", out_path: str = ""):
|
@staticmethod
|
||||||
|
def get_speaker_manager(config: Coqpit,
|
||||||
|
restore_path: str = "",
|
||||||
|
out_path: str = "",
|
||||||
|
data_train: List = []) -> SpeakerManager:
|
||||||
speaker_manager = SpeakerManager()
|
speaker_manager = SpeakerManager()
|
||||||
if restore_path:
|
if config.use_speaker_embedding:
|
||||||
speakers_file = os.path.join(os.path.dirname(restore_path),
|
if restore_path:
|
||||||
"speaker.json")
|
speakers_file = os.path.join(os.path.dirname(restore_path),
|
||||||
if not os.path.exists(speakers_file):
|
"speaker.json")
|
||||||
logging.info(
|
if not os.path.exists(speakers_file):
|
||||||
"WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
|
print(
|
||||||
)
|
"WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
|
||||||
speakers_file = self.config.external_speaker_embedding_file
|
)
|
||||||
|
speakers_file = config.external_speaker_embedding_file
|
||||||
|
|
||||||
if self.config.use_external_speaker_embedding_file:
|
if config.use_external_speaker_embedding_file:
|
||||||
speaker_manager.load_x_vectors_file(speakers_file)
|
speaker_manager.load_x_vectors_file(speakers_file)
|
||||||
|
else:
|
||||||
|
speaker_manager.load_ids_file(speakers_file)
|
||||||
|
elif config.use_external_speaker_embedding_file and config.external_speaker_embedding_file:
|
||||||
|
speaker_manager.load_x_vectors_file(
|
||||||
|
config.external_speaker_embedding_file)
|
||||||
else:
|
else:
|
||||||
self.speaker_manage.load_speaker_mapping(speakers_file)
|
speaker_manager.parse_speakers_from_items(data_train)
|
||||||
elif self.config.use_external_speaker_embedding_file and self.config.external_speaker_embedding_file:
|
file_path = os.path.join(out_path, "speakers.json")
|
||||||
speaker_manager.load_x_vectors_file(
|
speaker_manager.save_ids_file(file_path)
|
||||||
self.config.external_speaker_embedding_file)
|
|
||||||
else:
|
|
||||||
speaker_manager.parse_speakers_from_items(self.data_train)
|
|
||||||
file_path = os.path.join(out_path, "speakers.json")
|
|
||||||
speaker_manager.save_ids_file(file_path)
|
|
||||||
return speaker_manager
|
return speaker_manager
|
||||||
|
|
||||||
def init_scheduler(self, config, optimizer):
|
@staticmethod
|
||||||
|
def get_scheduler(config: Coqpit,
|
||||||
|
optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
|
||||||
lr_scheduler = config.lr_scheduler
|
lr_scheduler = config.lr_scheduler
|
||||||
lr_scheduler_params = config.lr_scheduler_params
|
lr_scheduler_params = config.lr_scheduler_params
|
||||||
if lr_scheduler is None:
|
if lr_scheduler is None:
|
||||||
|
@ -213,17 +224,20 @@ class TrainerTTS:
|
||||||
scheduler = getattr(torch.optim, lr_scheduler)
|
scheduler = getattr(torch.optim, lr_scheduler)
|
||||||
return scheduler(optimizer, **lr_scheduler_params)
|
return scheduler(optimizer, **lr_scheduler_params)
|
||||||
|
|
||||||
def init_criterion(self):
|
@staticmethod
|
||||||
return setup_loss(self.config)
|
def get_criterion(config: Coqpit) -> nn.Module:
|
||||||
|
return setup_loss(config)
|
||||||
|
|
||||||
def restore_model(self,
|
def restore_model(
|
||||||
config,
|
self,
|
||||||
restore_path,
|
config: Coqpit,
|
||||||
model,
|
restore_path: str,
|
||||||
optimizer,
|
model: nn.Module,
|
||||||
scaler=None):
|
optimizer: torch.optim.Optimizer,
|
||||||
logging.info(f" > Restoring from {os.path.basename(restore_path)}...")
|
scaler: torch.cuda.amp.GradScaler = None
|
||||||
checkpoint = torch.load(restore_path, map_location="cpu")
|
) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]:
|
||||||
|
print(" > Restoring from %s ..." % os.path.basename(restore_path))
|
||||||
|
checkpoint = torch.load(restore_path)
|
||||||
try:
|
try:
|
||||||
logging.info(" > Restoring Model...")
|
logging.info(" > Restoring Model...")
|
||||||
model.load_state_dict(checkpoint["model"])
|
model.load_state_dict(checkpoint["model"])
|
||||||
|
@ -242,20 +256,20 @@ class TrainerTTS:
|
||||||
|
|
||||||
for group in optimizer.param_groups:
|
for group in optimizer.param_groups:
|
||||||
group["lr"] = self.config.lr
|
group["lr"] = self.config.lr
|
||||||
logging.info(" > Model restored from step %d" % checkpoint["step"],
|
print(" > Model restored from step %d" % checkpoint["step"], )
|
||||||
flush=True)
|
|
||||||
restore_step = checkpoint["step"]
|
restore_step = checkpoint["step"]
|
||||||
return model, optimizer, scaler, restore_step
|
return model, optimizer, scaler, restore_step
|
||||||
|
|
||||||
def _setup_loader(self, r, ap, is_eval, data_items, verbose,
|
def _get_loader(self, r: int, ap: AudioProcessor, is_eval: bool,
|
||||||
speaker_mapping):
|
data_items: List, verbose: bool,
|
||||||
|
speaker_mapping: Union[Dict, List]) -> DataLoader:
|
||||||
if is_eval and not self.config.run_eval:
|
if is_eval and not self.config.run_eval:
|
||||||
loader = None
|
loader = None
|
||||||
else:
|
else:
|
||||||
dataset = TTSDataset(
|
dataset = TTSDataset(
|
||||||
outputs_per_step=r,
|
outputs_per_step=r,
|
||||||
text_cleaner=self.config.text_cleaner,
|
text_cleaner=self.config.text_cleaner,
|
||||||
compute_linear_spec= 'tacotron' == self.config.model.lower(),
|
compute_linear_spec=self.config.model.lower() == "tacotron",
|
||||||
meta_data=data_items,
|
meta_data=data_items,
|
||||||
ap=ap,
|
ap=ap,
|
||||||
tp=self.config.characters,
|
tp=self.config.characters,
|
||||||
|
@ -296,17 +310,19 @@ class TrainerTTS:
|
||||||
)
|
)
|
||||||
return loader
|
return loader
|
||||||
|
|
||||||
def setup_train_dataloader(self, r, ap, data_items, verbose,
|
def get_train_dataloader(self, r: int, ap: AudioProcessor,
|
||||||
speaker_mapping):
|
data_items: List, verbose: bool,
|
||||||
return self._setup_loader(r, ap, False, data_items, verbose,
|
speaker_mapping: Union[List, Dict]) -> DataLoader:
|
||||||
speaker_mapping)
|
return self._get_loader(r, ap, False, data_items, verbose,
|
||||||
|
speaker_mapping)
|
||||||
|
|
||||||
def setup_eval_dataloder(self, r, ap, data_items, verbose,
|
def get_eval_dataloder(self, r: int, ap: AudioProcessor, data_items: List,
|
||||||
speaker_mapping):
|
verbose: bool,
|
||||||
return self._setup_loader(r, ap, True, data_items, verbose,
|
speaker_mapping: Union[List, Dict]) -> DataLoader:
|
||||||
speaker_mapping)
|
return self._get_loader(r, ap, True, data_items, verbose,
|
||||||
|
speaker_mapping)
|
||||||
|
|
||||||
def format_batch(self, batch):
|
def format_batch(self, batch: List) -> Dict:
|
||||||
# setup input batch
|
# setup input batch
|
||||||
text_input = batch[0]
|
text_input = batch[0]
|
||||||
text_lengths = batch[1]
|
text_lengths = batch[1]
|
||||||
|
@ -401,7 +417,8 @@ class TrainerTTS:
|
||||||
"item_idx": item_idx
|
"item_idx": item_idx
|
||||||
}
|
}
|
||||||
|
|
||||||
def train_step(self, batch, batch_n_steps, step, loader_start_time):
|
def train_step(self, batch: Dict, batch_n_steps: int, step: int,
|
||||||
|
loader_start_time: float) -> Tuple[Dict, Dict]:
|
||||||
self.on_train_step_start()
|
self.on_train_step_start()
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
@ -515,7 +532,7 @@ class TrainerTTS:
|
||||||
self.on_train_step_end()
|
self.on_train_step_end()
|
||||||
return outputs, loss_dict
|
return outputs, loss_dict
|
||||||
|
|
||||||
def train_epoch(self):
|
def train_epoch(self) -> None:
|
||||||
self.model.train()
|
self.model.train()
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
if self.use_cuda:
|
if self.use_cuda:
|
||||||
|
@ -541,7 +558,7 @@ class TrainerTTS:
|
||||||
self.tb_logger.tb_model_weights(self.model,
|
self.tb_logger.tb_model_weights(self.model,
|
||||||
self.total_steps_done)
|
self.total_steps_done)
|
||||||
|
|
||||||
def eval_step(self, batch, step):
|
def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
|
|
||||||
|
@ -572,17 +589,11 @@ class TrainerTTS:
|
||||||
self.keep_avg_eval.avg_values)
|
self.keep_avg_eval.avg_values)
|
||||||
return outputs, loss_dict
|
return outputs, loss_dict
|
||||||
|
|
||||||
def eval_epoch(self):
|
def eval_epoch(self) -> None:
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
if self.use_cuda:
|
|
||||||
batch_num_steps = int(
|
|
||||||
len(self.train_loader.dataset) /
|
|
||||||
(self.config.batch_size * self.num_gpus))
|
|
||||||
else:
|
|
||||||
batch_num_steps = int(
|
|
||||||
len(self.train_loader.dataset) / self.config.batch_size)
|
|
||||||
self.c_logger.print_eval_start()
|
self.c_logger.print_eval_start()
|
||||||
loader_start_time = time.time()
|
loader_start_time = time.time()
|
||||||
|
batch = None
|
||||||
for cur_step, batch in enumerate(self.eval_loader):
|
for cur_step, batch in enumerate(self.eval_loader):
|
||||||
# format data
|
# format data
|
||||||
batch = self.format_batch(batch)
|
batch = self.format_batch(batch)
|
||||||
|
@ -597,8 +608,8 @@ class TrainerTTS:
|
||||||
{"EvalAudio": eval_audios},
|
{"EvalAudio": eval_audios},
|
||||||
self.ap.sample_rate)
|
self.ap.sample_rate)
|
||||||
|
|
||||||
def test_run(self, ):
|
def test_run(self, ) -> None:
|
||||||
logging.info(" | > Synthesizing test sentences.")
|
print(" | > Synthesizing test sentences.")
|
||||||
test_audios = {}
|
test_audios = {}
|
||||||
test_figures = {}
|
test_figures = {}
|
||||||
test_sentences = self.config.test_sentences
|
test_sentences = self.config.test_sentences
|
||||||
|
@ -618,9 +629,11 @@ class TrainerTTS:
|
||||||
do_trim_silence=False,
|
do_trim_silence=False,
|
||||||
).values()
|
).values()
|
||||||
|
|
||||||
file_path = os.path.join(self.output_audio_path, str(self.total_steps_done))
|
file_path = os.path.join(self.output_audio_path,
|
||||||
|
str(self.total_steps_done))
|
||||||
os.makedirs(file_path, exist_ok=True)
|
os.makedirs(file_path, exist_ok=True)
|
||||||
file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
|
file_path = os.path.join(file_path,
|
||||||
|
"TestSentence_{}.wav".format(idx))
|
||||||
self.ap.save_wav(wav, file_path)
|
self.ap.save_wav(wav, file_path)
|
||||||
test_audios["{}-audio".format(idx)] = wav
|
test_audios["{}-audio".format(idx)] = wav
|
||||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
|
test_figures["{}-prediction".format(idx)] = plot_spectrogram(
|
||||||
|
@ -629,16 +642,17 @@ class TrainerTTS:
|
||||||
alignment, output_fig=False)
|
alignment, output_fig=False)
|
||||||
|
|
||||||
self.tb_logger.tb_test_audios(self.total_steps_done, test_audios,
|
self.tb_logger.tb_test_audios(self.total_steps_done, test_audios,
|
||||||
self.config.audio["sample_rate"])
|
self.config.audio["sample_rate"])
|
||||||
self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)
|
self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)
|
||||||
|
|
||||||
def _get_cond_inputs(self):
|
def _get_cond_inputs(self) -> Dict:
|
||||||
# setup speaker_id
|
# setup speaker_id
|
||||||
speaker_id = 0 if self.config.use_speaker_embedding else None
|
speaker_id = 0 if self.config.use_speaker_embedding else None
|
||||||
# setup x_vector
|
# setup x_vector
|
||||||
x_vector = self.speaker_manager.get_x_vectors_by_speaker(
|
x_vector = (self.speaker_manager.get_x_vectors_by_speaker(
|
||||||
self.speaker_manager.speaker_ids[0]
|
self.speaker_manager.speaker_ids[0])
|
||||||
) if self.config.use_external_speaker_embedding_file and self.config.use_speaker_embedding else None
|
if self.config.use_external_speaker_embedding_file
|
||||||
|
and self.config.use_speaker_embedding else None)
|
||||||
# setup style_mel
|
# setup style_mel
|
||||||
if self.config.has('gst_style_input'):
|
if self.config.has('gst_style_input'):
|
||||||
style_wav = self.config.gst_style_input
|
style_wav = self.config.gst_style_input
|
||||||
|
@ -647,35 +661,40 @@ class TrainerTTS:
|
||||||
if style_wav is None and 'use_gst' in self.config and self.config.use_gst:
|
if style_wav is None and 'use_gst' in self.config and self.config.use_gst:
|
||||||
# inicialize GST with zero dict.
|
# inicialize GST with zero dict.
|
||||||
style_wav = {}
|
style_wav = {}
|
||||||
print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
|
print(
|
||||||
|
"WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!"
|
||||||
|
)
|
||||||
for i in range(self.config.gst["gst_num_style_tokens"]):
|
for i in range(self.config.gst["gst_num_style_tokens"]):
|
||||||
style_wav[str(i)] = 0
|
style_wav[str(i)] = 0
|
||||||
cond_inputs = {'speaker_id': speaker_id, 'style_wav': style_wav, 'x_vector': x_vector}
|
cond_inputs = {
|
||||||
|
"speaker_id": speaker_id,
|
||||||
|
"style_wav": style_wav,
|
||||||
|
"x_vector": x_vector
|
||||||
|
}
|
||||||
return cond_inputs
|
return cond_inputs
|
||||||
|
|
||||||
def fit(self):
|
def fit(self) -> None:
|
||||||
if self.restore_step != 0 or self.args.best_path:
|
if self.restore_step != 0 or self.args.best_path:
|
||||||
logging.info(" > Restoring best loss from "
|
print(" > Restoring best loss from "
|
||||||
f"{os.path.basename(self.args.best_path)} ...")
|
f"{os.path.basename(self.args.best_path)} ...")
|
||||||
self.best_loss = torch.load(self.args.best_path,
|
self.best_loss = torch.load(self.args.best_path,
|
||||||
map_location="cpu")["model_loss"]
|
map_location="cpu")["model_loss"]
|
||||||
logging.info(
|
print(f" > Starting with loaded last best loss {self.best_loss}.")
|
||||||
f" > Starting with loaded last best loss {self.best_loss}.")
|
|
||||||
|
|
||||||
# define data loaders
|
# define data loaders
|
||||||
self.train_loader = self.setup_train_dataloader(
|
self.train_loader = self.get_train_dataloader(
|
||||||
self.config.r,
|
self.config.r,
|
||||||
self.ap,
|
self.ap,
|
||||||
self.data_train,
|
self.data_train,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
speaker_mapping=self.speaker_manager.speaker_ids)
|
speaker_mapping=self.speaker_manager.speaker_ids)
|
||||||
self.eval_loader = self.setup_eval_dataloder(
|
self.eval_loader = (self.get_eval_dataloder(
|
||||||
self.config.r,
|
self.config.r,
|
||||||
self.ap,
|
self.ap,
|
||||||
self.data_train,
|
self.data_train,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
speaker_mapping=self.speaker_manager.speaker_ids
|
speaker_mapping=self.speaker_manager.speaker_ids)
|
||||||
) if self.config.run_eval else None
|
if self.config.run_eval else None)
|
||||||
|
|
||||||
self.total_steps_done = self.restore_step
|
self.total_steps_done = self.restore_step
|
||||||
|
|
||||||
|
@ -697,10 +716,10 @@ class TrainerTTS:
|
||||||
self.save_best_model()
|
self.save_best_model()
|
||||||
self.on_epoch_end()
|
self.on_epoch_end()
|
||||||
|
|
||||||
def save_best_model(self):
|
def save_best_model(self) -> None:
|
||||||
self.best_loss = save_best_model(
|
self.best_loss = save_best_model(
|
||||||
self.keep_avg_eval['avg_loss']
|
self.keep_avg_eval["avg_loss"]
|
||||||
if self.keep_avg_eval else self.keep_avg_train['avg_loss'],
|
if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
|
||||||
self.best_loss,
|
self.best_loss,
|
||||||
self.model,
|
self.model,
|
||||||
self.optimizer,
|
self.optimizer,
|
||||||
|
@ -715,8 +734,16 @@ class TrainerTTS:
|
||||||
if self.config.mixed_precision else None,
|
if self.config.mixed_precision else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_epoch_start(self):
|
@staticmethod
|
||||||
if hasattr(self.model, 'on_epoch_start'):
|
def _setup_logger_config(log_file: str) -> None:
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="",
|
||||||
|
handlers=[logging.FileHandler(log_file),
|
||||||
|
logging.StreamHandler()])
|
||||||
|
|
||||||
|
def on_epoch_start(self) -> None: # pylint: disable=no-self-use
|
||||||
|
if hasattr(self.model, "on_epoch_start"):
|
||||||
self.model.on_epoch_start(self)
|
self.model.on_epoch_start(self)
|
||||||
|
|
||||||
if hasattr(self.criterion, "on_epoch_start"):
|
if hasattr(self.criterion, "on_epoch_start"):
|
||||||
|
@ -725,8 +752,8 @@ class TrainerTTS:
|
||||||
if hasattr(self.optimizer, "on_epoch_start"):
|
if hasattr(self.optimizer, "on_epoch_start"):
|
||||||
self.optimizer.on_epoch_start(self)
|
self.optimizer.on_epoch_start(self)
|
||||||
|
|
||||||
def on_epoch_end(self):
|
def on_epoch_end(self) -> None: # pylint: disable=no-self-use
|
||||||
if hasattr(self.model, "on_epoch_start"):
|
if hasattr(self.model, "on_epoch_end"):
|
||||||
self.model.on_epoch_end(self)
|
self.model.on_epoch_end(self)
|
||||||
|
|
||||||
if hasattr(self.criterion, "on_epoch_end"):
|
if hasattr(self.criterion, "on_epoch_end"):
|
||||||
|
@ -735,8 +762,8 @@ class TrainerTTS:
|
||||||
if hasattr(self.optimizer, "on_epoch_end"):
|
if hasattr(self.optimizer, "on_epoch_end"):
|
||||||
self.optimizer.on_epoch_end(self)
|
self.optimizer.on_epoch_end(self)
|
||||||
|
|
||||||
def on_train_step_start(self):
|
def on_train_step_start(self) -> None: # pylint: disable=no-self-use
|
||||||
if hasattr(self.model, "on_epoch_start"):
|
if hasattr(self.model, "on_train_step_start"):
|
||||||
self.model.on_train_step_start(self)
|
self.model.on_train_step_start(self)
|
||||||
|
|
||||||
if hasattr(self.criterion, "on_train_step_start"):
|
if hasattr(self.criterion, "on_train_step_start"):
|
||||||
|
@ -745,7 +772,7 @@ class TrainerTTS:
|
||||||
if hasattr(self.optimizer, "on_train_step_start"):
|
if hasattr(self.optimizer, "on_train_step_start"):
|
||||||
self.optimizer.on_train_step_start(self)
|
self.optimizer.on_train_step_start(self)
|
||||||
|
|
||||||
def on_train_step_end(self):
|
def on_train_step_end(self) -> None: # pylint: disable=no-self-use
|
||||||
if hasattr(self.model, "on_train_step_end"):
|
if hasattr(self.model, "on_train_step_end"):
|
||||||
self.model.on_train_step_end(self)
|
self.model.on_train_step_end(self)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue