mirror of https://github.com/coqui-ai/TTS.git
Update `trainer.py`
Fix multi-speaker initialization of models. Add changes for end2end `tts` models.
parent b7f387b3dd
commit bf562cf437

TTS/trainer.py (147 changed lines)
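
Two of the behavioral changes in the diff below are easy to miss among the mechanical edits: `get_model` now tries the vocoder factory first and falls back to a `tts` model on `ModuleNotFoundError`, and `Trainer.__init__` calls the model's `init_multispeaker` hook (when the model defines one) with the combined train and eval samples. A minimal, standalone sketch of that flow, with stub classes standing in for the real coqui-ai/TTS models and configs, might look like this:

```python
# Standalone sketch of the model setup / multi-speaker flow added in this
# commit. StubTTSModel and the setup_* stubs are placeholders, not
# coqui-ai/TTS classes; only the control flow mirrors the diff below.

class StubTTSModel:
    def init_multispeaker(self, config, samples):
        # the real models build speaker embeddings from sample metadata;
        # here we only count distinct speakers in (text, wav, speaker) tuples
        self.num_speakers = len({speaker for _, _, speaker in samples})

def setup_vocoder_model(config):
    # stand-in for the vocoder factory; pretend this config is not a
    # vocoder config so the fallback path is exercised
    raise ModuleNotFoundError("not a vocoder config")

def setup_tts_model(config):
    # stand-in for the `tts` model factory
    return StubTTSModel()

def get_model(config):
    # new order in this commit: vocoder first, `tts` model as the fallback
    try:
        model = setup_vocoder_model(config)
    except ModuleNotFoundError:
        model = setup_tts_model(config)
    return model

if __name__ == "__main__":
    config = {}
    data_train = [("hello", "a.wav", "speaker_1"), ("world", "b.wav", "speaker_2")]
    data_eval = [("bye", "c.wav", "speaker_1")]
    model = get_model(config)
    # as in the new Trainer.__init__: let the model set up its speaker handling
    if hasattr(model, "init_multispeaker"):
        model.init_multispeaker(config, data_train + data_eval)
    print(model.num_speakers)  # 2
```

The swapped fallback order presumably lets the same entry point serve both vocoder and end-to-end `tts` configs without extra branching.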
@@ -2,6 +2,7 @@
 
 import importlib
 import logging
+import multiprocessing
 import os
 import platform
 import re
@@ -42,6 +43,8 @@ from TTS.utils.trainer_utils import get_optimizer, get_scheduler, is_apex_availa
 from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
 from TTS.vocoder.models import setup_model as setup_vocoder_model
 
+multiprocessing.set_start_method("fork")
+
 if platform.system() != "Windows":
     # https://github.com/pytorch/pytorch/issues/973
     import resource
@@ -149,7 +152,6 @@ class Trainer:
 
         # set and initialize Pytorch runtime
         self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
-
         if config is None:
             # parse config from console arguments
             config, output_path, _, c_logger, tb_logger = process_args(args)
@@ -184,7 +186,7 @@ class Trainer:
         # init audio processor
         self.ap = AudioProcessor(**self.config.audio.to_dict())
 
-        # load dataset samples
+        # load data samples
         # TODO: refactor this
         if "datasets" in self.config:
             # load data for `tts` models
@@ -205,6 +207,10 @@ class Trainer:
         else:
             self.model = self.get_model(self.config)
 
+        # init multispeaker settings of the model
+        if hasattr(self.model, "init_multispeaker"):
+            self.model.init_multispeaker(self.config, self.data_train + self.data_eval)
+
         # setup criterion
         self.criterion = self.get_criterion(self.model)
 
@@ -274,9 +280,9 @@ class Trainer:
         """
         # TODO: better model setup
         try:
-            model = setup_tts_model(config)
-        except ModuleNotFoundError:
             model = setup_vocoder_model(config)
+        except ModuleNotFoundError:
+            model = setup_tts_model(config)
         return model
 
     def restore_model(
@@ -417,7 +423,7 @@ class Trainer:
         scheduler: Union[torch.optim.lr_scheduler._LRScheduler, List],  # pylint: disable=protected-access
         config: Coqpit,
         optimizer_idx: int = None,
-    ) -> Tuple[Dict, Dict, int, torch.Tensor]:
+    ) -> Tuple[Dict, Dict, int]:
         """Perform a forward - backward pass and run the optimizer.
 
         Args:
@@ -426,7 +432,7 @@ class Trainer:
             optimizer (Union[nn.optim.Optimizer, List]): Model's optimizer. If it is a list then, `optimizer_idx` must be defined to indicate the optimizer in use.
             scaler (AMPScaler): AMP scaler.
             criterion (nn.Module): Model's criterion.
-            scheduler (Union[torch.optim.lr_scheduler._LRScheduler, List]): LR scheduler used by the optimizer.
+            scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler used by the optimizer.
             config (Coqpit): Model config.
             optimizer_idx (int, optional): Target optimizer being used. Defaults to None.
 
@@ -436,6 +442,7 @@ class Trainer:
         Returns:
             Tuple[Dict, Dict, int, torch.Tensor]: model outputs, losses, step time and gradient norm.
         """
+
         step_start_time = time.time()
         # zero-out optimizer
         optimizer.zero_grad()
@@ -448,11 +455,11 @@ class Trainer:
         # skip the rest
         if outputs is None:
             step_time = time.time() - step_start_time
-            return None, {}, step_time, 0
+            return None, {}, step_time
 
         # check nan loss
         if torch.isnan(loss_dict["loss"]).any():
-            raise RuntimeError(f"Detected NaN loss at step {self.total_steps_done}.")
+            raise RuntimeError(f" > Detected NaN loss - {loss_dict}.")
 
         # set gradient clipping threshold
         if "grad_clip" in config and config.grad_clip is not None:
@@ -463,7 +470,6 @@ class Trainer:
         else:
             grad_clip = 0.0  # meaning no gradient clipping
 
-        # TODO: compute grad norm
         if grad_clip <= 0:
             grad_norm = 0
 
@@ -474,15 +480,17 @@ class Trainer:
                 with amp.scale_loss(loss_dict["loss"], optimizer) as scaled_loss:
                     scaled_loss.backward()
                 grad_norm = torch.nn.utils.clip_grad_norm_(
-                    amp.master_params(optimizer),
-                    grad_clip,
+                    amp.master_params(optimizer), grad_clip, error_if_nonfinite=False
                 )
             else:
                 # model optimizer step in mixed precision mode
                 scaler.scale(loss_dict["loss"]).backward()
-                scaler.unscale_(optimizer)
                 if grad_clip > 0:
-                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+                    scaler.unscale_(optimizer)
+                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
+                    # pytorch skips the step when the norm is 0. So ignore the norm value when it is NaN
+                    if torch.isnan(grad_norm) or torch.isinf(grad_norm):
+                        grad_norm = 0
                 scale_prev = scaler.get_scale()
                 scaler.step(optimizer)
                 scaler.update()
@@ -491,13 +499,13 @@ class Trainer:
             # main model optimizer step
             loss_dict["loss"].backward()
             if grad_clip > 0:
-                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip, error_if_nonfinite=False)
             optimizer.step()
 
         step_time = time.time() - step_start_time
 
         # setup lr
-        if scheduler is not None and update_lr_scheduler:
+        if scheduler is not None and update_lr_scheduler and not self.config.scheduler_after_epoch:
             scheduler.step()
 
         # detach losses
@@ -505,7 +513,9 @@ class Trainer:
         if optimizer_idx is not None:
             loss_dict[f"loss_{optimizer_idx}"] = loss_dict.pop("loss")
             loss_dict[f"grad_norm_{optimizer_idx}"] = grad_norm
-        return outputs, loss_dict, step_time, grad_norm
+        else:
+            loss_dict["grad_norm"] = grad_norm
+        return outputs, loss_dict, step_time
 
     @staticmethod
     def _detach_loss_dict(loss_dict: Dict) -> Dict:
@@ -544,11 +554,10 @@ class Trainer:
 
         # conteainers to hold model outputs and losses for each optimizer.
         outputs_per_optimizer = None
-        log_dict = {}
         loss_dict = {}
         if not isinstance(self.optimizer, list):
             # training with a single optimizer
-            outputs, loss_dict_new, step_time, grad_norm = self._optimize(
+            outputs, loss_dict_new, step_time = self._optimize(
                 batch, self.model, self.optimizer, self.scaler, self.criterion, self.scheduler, self.config
             )
             loss_dict.update(loss_dict_new)
@@ -560,25 +569,36 @@ class Trainer:
                 criterion = self.criterion
                 scaler = self.scaler[idx] if self.use_amp_scaler else None
                 scheduler = self.scheduler[idx]
-                outputs, loss_dict_new, step_time, grad_norm = self._optimize(
+                outputs, loss_dict_new, step_time = self._optimize(
                     batch, self.model, optimizer, scaler, criterion, scheduler, self.config, idx
                 )
                 # skip the rest if the model returns None
                 total_step_time += step_time
                 outputs_per_optimizer[idx] = outputs
+                # merge loss_dicts from each optimizer
+                # rename duplicates with the optimizer idx
                 # if None, model skipped this optimizer
                 if loss_dict_new is not None:
-                    loss_dict.update(loss_dict_new)
+                    for k, v in loss_dict_new.items():
+                        if k in loss_dict:
+                            loss_dict[f"{k}-{idx}"] = v
+                        else:
+                            loss_dict[k] = v
+            step_time = total_step_time
             outputs = outputs_per_optimizer
 
-        # update avg stats
+        # update avg runtime stats
         keep_avg_update = dict()
-        for key, value in log_dict.items():
-            keep_avg_update["avg_" + key] = value
         keep_avg_update["avg_loader_time"] = loader_time
         keep_avg_update["avg_step_time"] = step_time
         self.keep_avg_train.update_values(keep_avg_update)
 
+        # update avg loss stats
+        update_eval_values = dict()
+        for key, value in loss_dict.items():
+            update_eval_values["avg_" + key] = value
+        self.keep_avg_train.update_values(update_eval_values)
+
         # print training progress
         if self.total_steps_done % self.config.print_step == 0:
             # log learning rates
@@ -590,33 +610,27 @@ class Trainer:
             else:
                 current_lr = self.optimizer.param_groups[0]["lr"]
             lrs = {"current_lr": current_lr}
-            log_dict.update(lrs)
-            if grad_norm > 0:
-                log_dict.update({"grad_norm": grad_norm})
             # log run-time stats
-            log_dict.update(
+            loss_dict.update(
                 {
                     "step_time": round(step_time, 4),
                     "loader_time": round(loader_time, 4),
                 }
             )
             self.c_logger.print_train_step(
-                batch_n_steps, step, self.total_steps_done, log_dict, loss_dict, self.keep_avg_train.avg_values
+                batch_n_steps, step, self.total_steps_done, loss_dict, self.keep_avg_train.avg_values
             )
 
         if self.args.rank == 0:
             # Plot Training Iter Stats
            # reduce TB load and don't log every step
            if self.total_steps_done % self.config.tb_plot_step == 0:
-                iter_stats = log_dict
-                iter_stats.update(loss_dict)
-                self.tb_logger.tb_train_step_stats(self.total_steps_done, iter_stats)
+                self.tb_logger.tb_train_step_stats(self.total_steps_done, loss_dict)
            if self.total_steps_done % self.config.save_step == 0 and self.total_steps_done != 0:
                if self.config.checkpoint:
                    # checkpoint the model
-                    model_loss = (
-                        loss_dict[self.config.target_loss] if "target_loss" in self.config else loss_dict["loss"]
-                    )
+                    target_avg_loss = self._pick_target_avg_loss(self.keep_avg_train)
                    save_checkpoint(
                        self.config,
                        self.model,
@@ -625,7 +639,7 @@ class Trainer:
                        self.total_steps_done,
                        self.epochs_done,
                        self.output_path,
-                        model_loss=model_loss,
+                        model_loss=target_avg_loss,
                    )
                # training visualizations
                figures, audios = None, None
@@ -666,6 +680,14 @@ class Trainer:
            self.tb_logger.tb_train_epoch_stats(self.total_steps_done, epoch_stats)
            if self.config.tb_model_param_stats:
                self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
+        # scheduler step after the epoch
+        if self.scheduler is not None and self.config.scheduler_after_epoch:
+            if isinstance(self.scheduler, list):
+                for scheduler in self.scheduler:
+                    if scheduler is not None:
+                        scheduler.step()
+            else:
+                self.scheduler.step()
 
     @staticmethod
     def _model_eval_step(
@@ -701,19 +723,22 @@ class Trainer:
            Tuple[Dict, Dict]: Model outputs and losses.
        """
        with torch.no_grad():
-            outputs_per_optimizer = None
+            outputs = []
            loss_dict = {}
            if not isinstance(self.optimizer, list):
                outputs, loss_dict = self._model_eval_step(batch, self.model, self.criterion)
            else:
-                outputs_per_optimizer = [None] * len(self.optimizer)
+                outputs = [None] * len(self.optimizer)
                for idx, _ in enumerate(self.optimizer):
                    criterion = self.criterion
-                    outputs, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
-                    outputs_per_optimizer[idx] = outputs
+                    outputs_, loss_dict_new = self._model_eval_step(batch, self.model, criterion, idx)
+                    outputs[idx] = outputs_
 
                    if loss_dict_new is not None:
+                        loss_dict_new[f"loss_{idx}"] = loss_dict_new.pop("loss")
                        loss_dict.update(loss_dict_new)
-                outputs = outputs_per_optimizer
+            loss_dict = self._detach_loss_dict(loss_dict)
 
        # update avg stats
        update_eval_values = dict()
@@ -764,6 +789,13 @@ class Trainer:
        """Run test and log the results. Test run must be defined by the model.
        Model must return figures and audios to be logged by the Tensorboard."""
        if hasattr(self.model, "test_run"):
+            if self.eval_loader is None:
+                self.eval_loader = self.get_eval_dataloader(
+                    self.ap,
+                    self.data_eval,
+                    verbose=True,
+                )
+
            if hasattr(self.eval_loader.dataset, "load_test_samples"):
                samples = self.eval_loader.dataset.load_test_samples(1)
                figures, audios = self.model.test_run(self.ap, samples, None)
@@ -816,10 +848,33 @@ class Trainer:
            traceback.print_exc()
            sys.exit(1)
 
+    def _pick_target_avg_loss(self, keep_avg_target: KeepAverage) -> Dict:
+        """Pick the target loss to compare models"""
+        target_avg_loss = None
+
+        # return if target loss defined in the model config
+        if "target_loss" in self.config and self.config.target_loss:
+            return keep_avg_target[f"avg_{self.config.target_loss}"]
+
+        # take the average of loss_{optimizer_idx} as the target loss when there are multiple optimizers
+        if isinstance(self.optimizer, list):
+            target_avg_loss = 0
+            for idx in range(len(self.optimizer)):
+                target_avg_loss += keep_avg_target[f"avg_loss_{idx}"]
+            target_avg_loss /= len(self.optimizer)
+        else:
+            target_avg_loss = keep_avg_target["avg_loss"]
+        return target_avg_loss
+
    def save_best_model(self) -> None:
        """Save the best model. It only saves if the current target loss is smaller then the previous."""
+
+        # set the target loss to choose the best model
+        target_loss_dict = self._pick_target_avg_loss(self.keep_avg_eval if self.keep_avg_eval else self.keep_avg_train)
+
+        # save the model and update the best_loss
        self.best_loss = save_best_model(
-            self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
+            target_loss_dict,
            self.best_loss,
            self.config,
            self.model,
@@ -927,7 +982,7 @@ class Trainer:
        return criterion
 
 
-def init_arguments():
+def getarguments():
    train_config = TrainingArgs()
    parser = train_config.init_argparse(arg_prefix="")
    return parser
@@ -1054,7 +1109,7 @@ def process_args(args, config=None):
        # if model characters are not set in the config file
        # save the default set to the config file for future
        # compatibility.
-        if config.has("characters_config"):
+        if config.has("characters") and config.characters is None:
            used_characters = parse_symbols()
            new_fields["characters"] = used_characters
        copy_model_files(config, experiment_path, new_fields)
@@ -1065,6 +1120,12 @@ def process_args(args, config=None):
    return config, experiment_path, audio_path, c_logger, tb_logger
 
 
+def init_arguments():
+    train_config = TrainingArgs()
+    parser = train_config.init_argparse(arg_prefix="")
+    return parser
+
+
 def init_training(argv: Union[List, Coqpit], config: Coqpit = None):
    """Initialization of a training run."""
    if isinstance(argv, Coqpit):
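
For checkpointing, the new `_pick_target_avg_loss` replaces the ad-hoc `model_loss` expression: it prefers an explicit `config.target_loss`, averages `loss_{idx}` across optimizers when several are in use, and otherwise falls back to the plain averaged loss. A simplified, standalone version of that selection rule (assumed signature, not the Trainer method itself):

```python
# Simplified sketch of the target-loss selection introduced above. The real
# method is Trainer._pick_target_avg_loss and reads self.config and
# self.optimizer; this free function only mirrors its decision order.

from typing import Dict, Optional


def pick_target_avg_loss(
    keep_avg: Dict[str, float],
    target_loss: Optional[str],
    num_optimizers: int,
) -> float:
    # an explicit target loss from the model config always wins
    if target_loss:
        return keep_avg[f"avg_{target_loss}"]
    # with several optimizers (e.g. GAN-style training) average their losses
    if num_optimizers > 1:
        total = sum(keep_avg[f"avg_loss_{idx}"] for idx in range(num_optimizers))
        return total / num_optimizers
    return keep_avg["avg_loss"]


# example: two optimizers, no explicit target loss -> mean of the two averages
averages = {"avg_loss_0": 0.75, "avg_loss_1": 0.25}
print(pick_target_avg_loss(averages, None, 2))  # 0.5
```

In the diff, both `save_checkpoint` and `save_best_model` are fed this value, so model comparison runs on the kept running averages rather than on the last step's raw loss dict.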