From 03494ad6428868ece287ca922715add4f3bd0fe6 Mon Sep 17 00:00:00 2001
From: Eren Gölge
Date: Sun, 6 Jun 2021 15:20:17 +0200
Subject: [PATCH] adjust `distribute.py` for the `train_tts.py`

---
 TTS/bin/distribute.py                         |  3 +-
 TTS/trainer.py                                | 52 +++++++++++++------
 TTS/tts/utils/synthesis.py                    |  6 ++-
 .../ljspeech/tacotron2-DDC/tacotron2-DDC.json |  7 ++-
 4 files changed, 48 insertions(+), 20 deletions(-)

diff --git a/TTS/bin/distribute.py b/TTS/bin/distribute.py
index ea43f88b..20d4bb20 100644
--- a/TTS/bin/distribute.py
+++ b/TTS/bin/distribute.py
@@ -30,7 +30,7 @@ def main():
     parser.add_argument(
         "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in sys.argv
     )
-    args = parser.parse_args()
+    args, unargs = parser.parse_known_args()
 
     num_gpus = torch.cuda.device_count()
     group_id = time.strftime("%Y_%m_%d-%H%M%S")
@@ -42,6 +42,7 @@ def main():
         command.append("--restore_path={}".format(args.restore_path))
         command.append("--config_path={}".format(args.config_path))
         command.append("--group_id=group_{}".format(group_id))
+        command += unargs
         command.append("")
 
     # run processes
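Note: swapping `parse_args()` for `parse_known_args()` is what lets `distribute.py` stay agnostic about the trainer's CLI: any flags it does not define itself are collected in `unargs` and appended verbatim to each spawned `train_tts.py` command. A minimal sketch of the stdlib behavior this relies on (the `--coqpit.batch_size` override flag is a hypothetical example, not defined by this patch):

    import argparse

    # parse_known_args() returns the recognized namespace plus the leftover
    # argv tokens instead of raising an error on unknown flags.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str)
    args, unargs = parser.parse_known_args(
        ["--config_path", "c.json", "--coqpit.batch_size", "32"]
    )
    assert args.config_path == "c.json"
    assert unargs == ["--coqpit.batch_size", "32"]  # later: command += unargs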
diff --git a/TTS/trainer.py b/TTS/trainer.py
index 9fe2f108..76c741b1 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -65,15 +65,19 @@ class TrainerTTS:
         self,
         args: Union[Coqpit, Namespace],
         config: Coqpit,
-        c_logger: ConsoleLogger,
-        tb_logger: TensorboardLogger,
+        c_logger: ConsoleLogger = None,
+        tb_logger: TensorboardLogger = None,
         model: nn.Module = None,
         output_path: str = None,
     ) -> None:
         self.args = args
         self.config = config
-        self.c_logger = c_logger
-        self.tb_logger = tb_logger
+        self.c_logger = ConsoleLogger() if c_logger is None else c_logger
+        if tb_logger is None:
+            self.tb_logger = TensorboardLogger(output_path, model_name=config.model)
+            self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
+        else:
+            self.tb_logger = tb_logger
         self.output_path = output_path
 
         self.total_steps_done = 0
@@ -117,20 +121,20 @@ class TrainerTTS:
         # setup criterion
         self.criterion = self.get_criterion(self.config)
 
-        if self.use_cuda:
-            self.model.cuda()
-            self.criterion.cuda()
-
         # DISTRUBUTED
         if self.num_gpus > 1:
             init_distributed(
                 args.rank,
                 self.num_gpus,
                 args.group_id,
-                self.config.distributed["backend"],
-                self.config.distributed["url"],
+                self.config.distributed_backend,
+                self.config.distributed_url,
             )
 
+        if self.use_cuda:
+            self.model.cuda()
+            self.criterion.cuda()
+
         # scalers for mixed precision training
         self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None
 
@@ -147,7 +151,7 @@ class TrainerTTS:
 
         # DISTRUBUTED
         if self.num_gpus > 1:
-            self.model = DDP_th(self.model, device_ids=[args.rank])
+            self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)
 
         # count model size
         num_params = count_parameters(self.model)
@@ -377,6 +381,11 @@ class TrainerTTS:
             "item_idx": item_idx,
         }
 
+    def _train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
+        if hasattr(self.model, "module"):
+            return self.model.module.train_step(batch, criterion)
+        return self.model.train_step(batch, criterion)
+
     def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
         self.on_train_step_start()
         step_start_time = time.time()
@@ -389,7 +398,7 @@ class TrainerTTS:
         self.optimizer.zero_grad()
 
         with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
-            outputs, loss_dict = self.model.train_step(batch, self.criterion)
+            outputs, loss_dict = self._train_step(batch, self.criterion)
 
         # check nan loss
         if torch.isnan(loss_dict["loss"]).any():
@@ -473,7 +482,10 @@ class TrainerTTS:
                     scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
                 )
                 # training visualizations
-                figures, audios = self.model.train_log(self.ap, batch, outputs)
+                if hasattr(self.model, "module"):
+                    figures, audios = self.model.module.train_log(self.ap, batch, outputs)
+                else:
+                    figures, audios = self.model.train_log(self.ap, batch, outputs)
                 self.tb_logger.tb_train_figures(self.total_steps_done, figures)
                 self.tb_logger.tb_train_audios(self.total_steps_done, {"TrainAudio": audios}, self.ap.sample_rate)
         self.total_steps_done += 1
@@ -500,12 +512,17 @@ class TrainerTTS:
         if self.config.tb_model_param_stats:
             self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
 
+    def _eval_step(self, batch: Dict) -> Tuple[Dict, Dict]:
+        if hasattr(self.model, "module"):
+            return self.model.module.eval_step(batch, self.criterion)
+        return self.model.eval_step(batch, self.criterion)
+
     def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
         with torch.no_grad():
             step_start_time = time.time()
 
             with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
-                outputs, loss_dict = self.model.eval_step(batch, self.criterion)
+                outputs, loss_dict = self._eval_step(batch)
 
             step_time = time.time() - step_start_time
 
@@ -542,7 +559,10 @@ class TrainerTTS:
             outputs, _ = self.eval_step(batch, cur_step)
 
         # Plot epoch stats and samples from the last batch.
         if self.args.rank == 0:
-            figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
+            if hasattr(self.model, "module"):
+                figures, eval_audios = self.model.module.eval_log(self.ap, batch, outputs)
+            else:
+                figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
             self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
             self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)
@@ -642,7 +662,7 @@ class TrainerTTS:
             self.train_epoch()
             if self.config.run_eval:
                 self.eval_epoch()
-            if epoch >= self.config.test_delay_epochs:
+            if epoch >= self.config.test_delay_epochs and self.args.rank < 0:
                 self.test_run()
             self.c_logger.print_epoch_end(
                 epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py
index 72eff2e5..46f919dc 100644
--- a/TTS/tts/utils/synthesis.py
+++ b/TTS/tts/utils/synthesis.py
@@ -87,7 +87,11 @@ def run_model_torch(
         Dict: model outputs.
     """
     input_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)
-    outputs = model.inference(
+    if hasattr(model, "module"):
+        _func = model.module.inference
+    else:
+        _func = model.inference
+    outputs = _func(
         inputs,
         cond_input={
             "x_lengths": input_lengths,
diff --git a/recipes/ljspeech/tacotron2-DDC/tacotron2-DDC.json b/recipes/ljspeech/tacotron2-DDC/tacotron2-DDC.json
index 9cdbbd3b..e3531851 100644
--- a/recipes/ljspeech/tacotron2-DDC/tacotron2-DDC.json
+++ b/recipes/ljspeech/tacotron2-DDC/tacotron2-DDC.json
@@ -36,12 +36,14 @@
         "gst_num_heads": 4,
         "gst_num_style_tokens": 10
     },
+    "distributed_backend": "gloo",
+    "distributed_url": "tcp:\/\/localhost:54321",
     "model": "Tacotron2",
     "run_name": "ljspeech-ddc",
     "run_description": "tacotron2 with double decoder consistency.",
     "batch_size": 64,
     "eval_batch_size": 16,
-    "mixed_precision": true,
+    "mixed_precision": false,
     "loss_masking": true,
     "decoder_loss_alpha": 0.25,
     "postnet_loss_alpha": 0.25,
@@ -54,6 +56,7 @@
     "run_eval": true,
     "test_delay_epochs": 10,
     "test_sentences_file": null,
+    "max_decoder_steps": 50,
     "noam_schedule": true,
     "grad_clip": 0.05,
     "epochs": 1000,
@@ -88,4 +91,4 @@
     "phoneme_cache_path": "DEFINE THIS",
     "use_phonemes": false,
     "phoneme_language": "en-us"
-}
\ No newline at end of file
+}
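Note: the recipe now carries flat `distributed_backend` / `distributed_url` keys matching the attribute lookups that replaced the nested `self.config.distributed[...]` dict in `trainer.py`. A sketch of how such values typically reach `torch.distributed` (assumed wiring; `init_distributed()`'s internals are not shown in this patch):

    import torch.distributed as dist

    # Each process spawned by distribute.py would make a call like this,
    # with its own rank; the call blocks until all peers have joined.
    dist.init_process_group(
        backend="gloo",                       # "distributed_backend"
        init_method="tcp://localhost:54321",  # "distributed_url"
        rank=0,                               # this process's rank
        world_size=2,                         # one process per GPU
    )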