adjust `distribute.py` for `train_tts.py`

Eren Gölge 2021-06-06 15:20:17 +02:00
parent 614738cc85
commit 59be1b9af1
4 changed files with 48 additions and 20 deletions

View File

@@ -30,7 +30,7 @@ def main():
 parser.add_argument(
     "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in sys.argv
 )
-args = parser.parse_args()
+args, unargs = parser.parse_known_args()
 num_gpus = torch.cuda.device_count()
 group_id = time.strftime("%Y_%m_%d-%H%M%S")
@@ -42,6 +42,7 @@
 command.append("--restore_path={}".format(args.restore_path))
 command.append("--config_path={}".format(args.config_path))
 command.append("--group_id=group_{}".format(group_id))
+command += unargs
 command.append("")
 # run processes
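The switch to `parse_known_args` lets `distribute.py` accept extra flags that are really meant for `train_tts.py` and forward them untouched to each spawned worker. Below is a minimal sketch of that pattern, assuming a hypothetical `train_tts.py` target and one subprocess per visible GPU; the exact flags and process handling in the project may differ.

import argparse
import os
import subprocess
import sys

import torch

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", type=str, required=True)
    # Known args are parsed; everything else is collected and forwarded as-is.
    args, unargs = parser.parse_known_args()

    num_gpus = torch.cuda.device_count()
    processes = []
    for rank in range(num_gpus):
        command = [sys.executable, "train_tts.py"]            # hypothetical target script
        command.append("--config_path={}".format(args.config_path))
        command.append("--rank={}".format(rank))              # assumed per-worker flag
        command += unargs                                      # pass unknown flags through
        processes.append(subprocess.Popen(command, env=os.environ.copy()))

    for p in processes:
        p.wait()

if __name__ == "__main__":
    main()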

View File

@@ -65,15 +65,19 @@ class TrainerTTS:
     self,
     args: Union[Coqpit, Namespace],
     config: Coqpit,
-    c_logger: ConsoleLogger,
-    tb_logger: TensorboardLogger,
+    c_logger: ConsoleLogger = None,
+    tb_logger: TensorboardLogger = None,
     model: nn.Module = None,
     output_path: str = None,
 ) -> None:
     self.args = args
     self.config = config
-    self.c_logger = c_logger
-    self.tb_logger = tb_logger
+    self.c_logger = ConsoleLogger() if c_logger is None else c_logger
+    if tb_logger is None:
+        self.tb_logger = TensorboardLogger(output_path, model_name=config.model)
+        self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
+    else:
+        self.tb_logger = tb_logger
     self.output_path = output_path
     self.total_steps_done = 0
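Making both loggers optional means a caller can construct `TrainerTTS` without wiring up logging first; the constructor builds defaults on demand. A small sketch of the same lazy-default pattern, using stand-in logger classes rather than the project's own:

class ConsoleLoggerStub:
    """Stand-in for the real console logger."""
    def print_epoch_end(self, epoch, values):
        print(f"epoch {epoch}: {values}")

class TBLoggerStub:
    """Stand-in for the real Tensorboard logger."""
    def __init__(self, output_path, model_name):
        self.output_path, self.model_name = output_path, model_name

class TrainerSketch:
    def __init__(self, config, c_logger=None, tb_logger=None, output_path=None):
        # Build defaults only when the caller did not supply a logger.
        self.c_logger = ConsoleLoggerStub() if c_logger is None else c_logger
        if tb_logger is None:
            tb_logger = TBLoggerStub(output_path, model_name=config["model"])
        self.tb_logger = tb_logger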
@@ -117,20 +121,20 @@ class TrainerTTS:
     # setup criterion
     self.criterion = self.get_criterion(self.config)
-    if self.use_cuda:
-        self.model.cuda()
-        self.criterion.cuda()
     # DISTRUBUTED
     if self.num_gpus > 1:
         init_distributed(
             args.rank,
             self.num_gpus,
             args.group_id,
-            self.config.distributed["backend"],
-            self.config.distributed["url"],
+            self.config.distributed_backend,
+            self.config.distributed_url,
         )
+    if self.use_cuda:
+        self.model.cuda()
+        self.criterion.cuda()
     # scalers for mixed precision training
     self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None
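Reading the backend and URL from flat `distributed_backend` / `distributed_url` fields, and moving the model and criterion to the GPU only after the process group is up, follows the usual `torch.distributed` bootstrap order. A rough sketch with plain PyTorch calls; the project's own `init_distributed` helper is not shown, so the rank/world-size plumbing here is illustrative:

import torch
import torch.distributed as dist

def setup_worker(model, criterion, rank, num_gpus, backend="gloo", url="tcp://localhost:54321"):
    if num_gpus > 1:
        # Initialize the process group first so every worker agrees on its rank and device.
        dist.init_process_group(backend=backend, init_method=url, world_size=num_gpus, rank=rank)
        torch.cuda.set_device(rank)
    if torch.cuda.is_available():
        # Only move modules to the GPU once the device for this rank is fixed.
        model.cuda()
        criterion.cuda()
    return model, criterion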
@@ -147,7 +151,7 @@ class TrainerTTS:
     # DISTRUBUTED
     if self.num_gpus > 1:
-        self.model = DDP_th(self.model, device_ids=[args.rank])
+        self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)
     # count model size
     num_params = count_parameters(self.model)
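Passing `output_device` alongside `device_ids` tells `DistributedDataParallel` which device the wrapped module's outputs are gathered on; with one GPU per process that is the same local rank. A minimal sketch of the wrapping step under that single-device-per-process assumption:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP_th

def wrap_for_ddp(model, rank, num_gpus):
    if num_gpus > 1:
        # One GPU per process: inputs, parameters, and outputs all live on device `rank`.
        model = DDP_th(model, device_ids=[rank], output_device=rank)
    return model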
@@ -377,6 +381,11 @@ class TrainerTTS:
         "item_idx": item_idx,
     }
+    def _train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
+        if hasattr(self.model, "module"):
+            return self.model.module.train_step(batch, criterion)
+        return self.model.train_step(batch, criterion)
     def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
         self.on_train_step_start()
         step_start_time = time.time()
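`DistributedDataParallel` only exposes `forward`, so custom methods such as `train_step`, `eval_step`, `train_log`, and `eval_log` have to be reached through the wrapped `.module` attribute; the `hasattr(self.model, "module")` check keeps single-GPU and DDP runs on one code path. A generic sketch of that dispatch (helper name is illustrative, not from the repo):

def call_model_method(model, method_name, *args, **kwargs):
    # DDP wraps the user's nn.Module as `model.module`; plain models are used directly.
    target = model.module if hasattr(model, "module") else model
    return getattr(target, method_name)(*args, **kwargs)

# e.g. outputs, loss_dict = call_model_method(self.model, "train_step", batch, self.criterion)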
@@ -389,7 +398,7 @@ class TrainerTTS:
         self.optimizer.zero_grad()
         with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
-            outputs, loss_dict = self.model.train_step(batch, self.criterion)
+            outputs, loss_dict = self._train_step(batch, self.criterion)
         # check nan loss
         if torch.isnan(loss_dict["loss"]).any():
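The training step runs under `torch.cuda.amp.autocast`, and the trainer creates a `GradScaler` when `mixed_precision` is enabled, so the loss is expected to be scaled before `backward()` and the optimizer step. A condensed sketch of that loop, independent of the trainer class and with placeholder forward/loss calls:

import torch

def amp_train_step(model, batch, criterion, optimizer, scaler, mixed_precision):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=mixed_precision):
        outputs = model(batch["inputs"])             # placeholder forward pass
        loss = criterion(outputs, batch["targets"])  # placeholder loss
    if torch.isnan(loss).any():
        raise RuntimeError("NaN loss detected, skipping optimizer step.")
    if scaler is not None:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()
    return outputs, loss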
@@ -473,7 +482,10 @@ class TrainerTTS:
             scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
         )
         # training visualizations
-        figures, audios = self.model.train_log(self.ap, batch, outputs)
+        if hasattr(self.model, "module"):
+            figures, audios = self.model.module.train_log(self.ap, batch, outputs)
+        else:
+            figures, audios = self.model.train_log(self.ap, batch, outputs)
         self.tb_logger.tb_train_figures(self.total_steps_done, figures)
         self.tb_logger.tb_train_audios(self.total_steps_done, {"TrainAudio": audios}, self.ap.sample_rate)
         self.total_steps_done += 1
@@ -500,12 +512,17 @@ class TrainerTTS:
         if self.config.tb_model_param_stats:
             self.tb_logger.tb_model_weights(self.model, self.total_steps_done)
+    def _eval_step(self, batch: Dict) -> Tuple[Dict, Dict]:
+        if hasattr(self.model, "module"):
+            return self.model.module.eval_step(batch, self.criterion)
+        return self.model.eval_step(batch, self.criterion)
     def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
         with torch.no_grad():
             step_start_time = time.time()
             with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
-                outputs, loss_dict = self.model.eval_step(batch, self.criterion)
+                outputs, loss_dict = self._eval_step(batch)
             step_time = time.time() - step_start_time
@@ -542,7 +559,10 @@ class TrainerTTS:
             outputs, _ = self.eval_step(batch, cur_step)
         # Plot epoch stats and samples from the last batch.
         if self.args.rank == 0:
-            figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
+            if hasattr(self.model, "module"):
+                figures, eval_audios = self.model.module.eval_log(self.ap, batch, outputs)
+            else:
+                figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
             self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
             self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)
@@ -642,7 +662,7 @@ class TrainerTTS:
             self.train_epoch()
             if self.config.run_eval:
                 self.eval_epoch()
-            if epoch >= self.config.test_delay_epochs:
+            if epoch >= self.config.test_delay_epochs and self.args.rank < 0:
                 self.test_run()
             self.c_logger.print_epoch_end(
                 epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
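Gating `test_run()` by rank, like the earlier rank-0 guard around the eval plots, is meant to stop every DDP worker from synthesizing test sentences and writing the same Tensorboard entries. The common convention is to let only the main process (rank 0) do this side-effect work; the sketch below shows that conventional guard, which is not the exact condition used in this commit:

def is_main_process(rank: int) -> bool:
    # By convention rank 0 is the only process that writes logs, checkpoints, or test audio.
    return rank == 0

def maybe_test_run(trainer, epoch: int):
    if epoch >= trainer.config.test_delay_epochs and is_main_process(trainer.args.rank):
        trainer.test_run()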

View File

@@ -86,7 +86,11 @@ def run_model_torch(
         Dict: model outputs.
     """
     input_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)
-    outputs = model.inference(
+    if hasattr(model, "module"):
+        _func = model.module.inference
+    else:
+        _func = model.inference
+    outputs = _func(
         inputs,
         cond_input={
             "x_lengths": input_lengths,

View File

@@ -36,12 +36,14 @@
         "gst_num_heads": 4,
         "gst_num_style_tokens": 10
     },
+    "distributed_backend": "gloo",
+    "distributed_url": "tcp:\/\/localhost:54321",
     "model": "Tacotron2",
     "run_name": "ljspeech-ddc",
     "run_description": "tacotron2 with double decoder consistency.",
     "batch_size": 64,
     "eval_batch_size": 16,
-    "mixed_precision": true,
+    "mixed_precision": false,
     "loss_masking": true,
     "decoder_loss_alpha": 0.25,
     "postnet_loss_alpha": 0.25,
@@ -54,6 +56,7 @@
     "run_eval": true,
     "test_delay_epochs": 10,
     "test_sentences_file": null,
+    "max_decoder_steps": 50,
     "noam_schedule": true,
     "grad_clip": 0.05,
     "epochs": 1000,
@@ -88,4 +91,4 @@
     "phoneme_cache_path": "DEFINE THIS",
     "use_phonemes": false,
     "phoneme_language": "en-us"
-}
+}
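The two new top-level keys are what the trainer now reads as `config.distributed_backend` and `config.distributed_url`, and `"mixed_precision": false` keeps the GradScaler path off for this config. A tiny sketch of loading such a JSON config and pulling those fields out, using plain `json` instead of the project's Coqpit config classes:

import json

def load_distributed_settings(config_path: str):
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    backend = config.get("distributed_backend", "gloo")           # "gloo" in this config
    url = config.get("distributed_url", "tcp://localhost:54321")
    mixed_precision = config.get("mixed_precision", False)
    return backend, url, mixed_precision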