mirror of https://github.com/coqui-ai/TTS.git
adjust `distribute.py` for the `train_tts.py`
This commit is contained in:
parent
fdfb18d230
commit
03494ad642
|
@ -30,7 +30,7 @@ def main():
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in sys.argv
|
"--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in sys.argv
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args, unargs = parser.parse_known_args()
|
||||||
|
|
||||||
num_gpus = torch.cuda.device_count()
|
num_gpus = torch.cuda.device_count()
|
||||||
group_id = time.strftime("%Y_%m_%d-%H%M%S")
|
group_id = time.strftime("%Y_%m_%d-%H%M%S")
|
||||||
|
@ -42,6 +42,7 @@ def main():
|
||||||
command.append("--restore_path={}".format(args.restore_path))
|
command.append("--restore_path={}".format(args.restore_path))
|
||||||
command.append("--config_path={}".format(args.config_path))
|
command.append("--config_path={}".format(args.config_path))
|
||||||
command.append("--group_id=group_{}".format(group_id))
|
command.append("--group_id=group_{}".format(group_id))
|
||||||
|
command += unargs
|
||||||
command.append("")
|
command.append("")
|
||||||
|
|
||||||
# run processes
|
# run processes
|
||||||
|
|
|
@ -65,15 +65,19 @@ class TrainerTTS:
|
||||||
self,
|
self,
|
||||||
args: Union[Coqpit, Namespace],
|
args: Union[Coqpit, Namespace],
|
||||||
config: Coqpit,
|
config: Coqpit,
|
||||||
c_logger: ConsoleLogger,
|
c_logger: ConsoleLogger = None,
|
||||||
tb_logger: TensorboardLogger,
|
tb_logger: TensorboardLogger = None,
|
||||||
model: nn.Module = None,
|
model: nn.Module = None,
|
||||||
output_path: str = None,
|
output_path: str = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.args = args
|
self.args = args
|
||||||
self.config = config
|
self.config = config
|
||||||
self.c_logger = c_logger
|
self.c_logger = ConsoleLogger() if c_logger is None else c_logger
|
||||||
self.tb_logger = tb_logger
|
if tb_logger is None:
|
||||||
|
self.tb_logger = TensorboardLogger(output_path, model_name=config.model)
|
||||||
|
self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
|
||||||
|
else:
|
||||||
|
self.tb_logger = tb_logger
|
||||||
self.output_path = output_path
|
self.output_path = output_path
|
||||||
|
|
||||||
self.total_steps_done = 0
|
self.total_steps_done = 0
|
||||||
|
@ -117,20 +121,20 @@ class TrainerTTS:
|
||||||
# setup criterion
|
# setup criterion
|
||||||
self.criterion = self.get_criterion(self.config)
|
self.criterion = self.get_criterion(self.config)
|
||||||
|
|
||||||
if self.use_cuda:
|
|
||||||
self.model.cuda()
|
|
||||||
self.criterion.cuda()
|
|
||||||
|
|
||||||
# DISTRUBUTED
|
# DISTRUBUTED
|
||||||
if self.num_gpus > 1:
|
if self.num_gpus > 1:
|
||||||
init_distributed(
|
init_distributed(
|
||||||
args.rank,
|
args.rank,
|
||||||
self.num_gpus,
|
self.num_gpus,
|
||||||
args.group_id,
|
args.group_id,
|
||||||
self.config.distributed["backend"],
|
self.config.distributed_backend,
|
||||||
self.config.distributed["url"],
|
self.config.distributed_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.use_cuda:
|
||||||
|
self.model.cuda()
|
||||||
|
self.criterion.cuda()
|
||||||
|
|
||||||
# scalers for mixed precision training
|
# scalers for mixed precision training
|
||||||
self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None
|
self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None
|
||||||
|
|
||||||
|
@ -147,7 +151,7 @@ class TrainerTTS:
|
||||||
|
|
||||||
# DISTRUBUTED
|
# DISTRUBUTED
|
||||||
if self.num_gpus > 1:
|
if self.num_gpus > 1:
|
||||||
self.model = DDP_th(self.model, device_ids=[args.rank])
|
self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)
|
||||||
|
|
||||||
# count model size
|
# count model size
|
||||||
num_params = count_parameters(self.model)
|
num_params = count_parameters(self.model)
|
||||||
|
@ -377,6 +381,11 @@ class TrainerTTS:
|
||||||
"item_idx": item_idx,
|
"item_idx": item_idx,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
|
||||||
|
if hasattr(self.model, "module"):
|
||||||
|
return self.model.module.train_step(batch, criterion)
|
||||||
|
return self.model.train_step(batch, criterion)
|
||||||
|
|
||||||
def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
|
def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
|
||||||
self.on_train_step_start()
|
self.on_train_step_start()
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
|
@ -389,7 +398,7 @@ class TrainerTTS:
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
|
with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
|
||||||
outputs, loss_dict = self.model.train_step(batch, self.criterion)
|
outputs, loss_dict = self._train_step(batch, self.criterion)
|
||||||
|
|
||||||
# check nan loss
|
# check nan loss
|
||||||
if torch.isnan(loss_dict["loss"]).any():
|
if torch.isnan(loss_dict["loss"]).any():
|
||||||
|
@ -473,7 +482,10 @@ class TrainerTTS:
|
||||||
scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
|
scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
|
||||||
)
|
)
|
||||||
# training visualizations
|
# training visualizations
|
||||||
figures, audios = self.model.train_log(self.ap, batch, outputs)
|
if hasattr(self.model, "module"):
|
||||||
|
figures, audios = self.model.module.train_log(self.ap, batch, outputs)
|
||||||
|
else:
|
||||||
|
figures, audios = self.model.train_log(self.ap, batch, outputs)
|
||||||
self.tb_logger.tb_train_figures(self.total_steps_done, figures)
|
self.tb_logger.tb_train_figures(self.total_steps_done, figures)
|
||||||
self.tb_logger.tb_train_audios(self.total_steps_done, {"TrainAudio": audios}, self.ap.sample_rate)
|
self.tb_logger.tb_train_audios(self.total_steps_done, {"TrainAudio": audios}, self.ap.sample_rate)
|
||||||
self.total_steps_done += 1
|
self.total_steps_done += 1
|
||||||
|
@ -500,12 +512,17 @@ class TrainerTTS:
|
||||||
if self.config.tb_model_param_stats:
|
if self.config.tb_model_param_stats:
|
||||||
self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
|
self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
|
||||||
|
|
||||||
|
def _eval_step(self, batch: Dict) -> Tuple[Dict, Dict]:
|
||||||
|
if hasattr(self.model, "module"):
|
||||||
|
return self.model.module.eval_step(batch, self.criterion)
|
||||||
|
return self.model.eval_step(batch, self.criterion)
|
||||||
|
|
||||||
def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
|
def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
step_start_time = time.time()
|
step_start_time = time.time()
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
|
with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
|
||||||
outputs, loss_dict = self.model.eval_step(batch, self.criterion)
|
outputs, loss_dict = self._eval_step(batch)
|
||||||
|
|
||||||
step_time = time.time() - step_start_time
|
step_time = time.time() - step_start_time
|
||||||
|
|
||||||
|
@ -542,7 +559,10 @@ class TrainerTTS:
|
||||||
outputs, _ = self.eval_step(batch, cur_step)
|
outputs, _ = self.eval_step(batch, cur_step)
|
||||||
# Plot epoch stats and samples from the last batch.
|
# Plot epoch stats and samples from the last batch.
|
||||||
if self.args.rank == 0:
|
if self.args.rank == 0:
|
||||||
figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
|
if hasattr(self.model, "module"):
|
||||||
|
figures, eval_audios = self.model.module.eval_log(self.ap, batch, outputs)
|
||||||
|
else:
|
||||||
|
figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
|
||||||
self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
|
self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
|
||||||
self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)
|
self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)
|
||||||
|
|
||||||
|
@ -642,7 +662,7 @@ class TrainerTTS:
|
||||||
self.train_epoch()
|
self.train_epoch()
|
||||||
if self.config.run_eval:
|
if self.config.run_eval:
|
||||||
self.eval_epoch()
|
self.eval_epoch()
|
||||||
if epoch >= self.config.test_delay_epochs:
|
if epoch >= self.config.test_delay_epochs and self.args.rank < 0:
|
||||||
self.test_run()
|
self.test_run()
|
||||||
self.c_logger.print_epoch_end(
|
self.c_logger.print_epoch_end(
|
||||||
epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
|
epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
|
||||||
|
|
|
@ -87,7 +87,11 @@ def run_model_torch(
|
||||||
Dict: model outputs.
|
Dict: model outputs.
|
||||||
"""
|
"""
|
||||||
input_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)
|
input_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)
|
||||||
outputs = model.inference(
|
if hasattr(model, "module"):
|
||||||
|
_func = model.module.inference
|
||||||
|
else:
|
||||||
|
_func = model.inference
|
||||||
|
outputs = _func(
|
||||||
inputs,
|
inputs,
|
||||||
cond_input={
|
cond_input={
|
||||||
"x_lengths": input_lengths,
|
"x_lengths": input_lengths,
|
||||||
|
|
|
@ -36,12 +36,14 @@
|
||||||
"gst_num_heads": 4,
|
"gst_num_heads": 4,
|
||||||
"gst_num_style_tokens": 10
|
"gst_num_style_tokens": 10
|
||||||
},
|
},
|
||||||
|
"distributed_backend": "gloo",
|
||||||
|
"distributed_url": "tcp:\/\/localhost:54321",
|
||||||
"model": "Tacotron2",
|
"model": "Tacotron2",
|
||||||
"run_name": "ljspeech-ddc",
|
"run_name": "ljspeech-ddc",
|
||||||
"run_description": "tacotron2 with double decoder consistency.",
|
"run_description": "tacotron2 with double decoder consistency.",
|
||||||
"batch_size": 64,
|
"batch_size": 64,
|
||||||
"eval_batch_size": 16,
|
"eval_batch_size": 16,
|
||||||
"mixed_precision": true,
|
"mixed_precision": false,
|
||||||
"loss_masking": true,
|
"loss_masking": true,
|
||||||
"decoder_loss_alpha": 0.25,
|
"decoder_loss_alpha": 0.25,
|
||||||
"postnet_loss_alpha": 0.25,
|
"postnet_loss_alpha": 0.25,
|
||||||
|
@ -54,6 +56,7 @@
|
||||||
"run_eval": true,
|
"run_eval": true,
|
||||||
"test_delay_epochs": 10,
|
"test_delay_epochs": 10,
|
||||||
"test_sentences_file": null,
|
"test_sentences_file": null,
|
||||||
|
"max_decoder_steps": 50,
|
||||||
"noam_schedule": true,
|
"noam_schedule": true,
|
||||||
"grad_clip": 0.05,
|
"grad_clip": 0.05,
|
||||||
"epochs": 1000,
|
"epochs": 1000,
|
||||||
|
|
Loading…
Reference in New Issue