diff --git a/TTS/trainer.py b/TTS/trainer.py
index 014a4340..32e561d6 100644
--- a/TTS/trainer.py
+++ b/TTS/trainer.py
@@ -84,7 +84,10 @@ class TrainingArgs(Coqpit):
     config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
     rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
     group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
-    use_ddp: bool= field(default=False, metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."})
+    use_ddp: bool = field(
+        default=False,
+        metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."},
+    )


 class Trainer:
@@ -362,7 +365,9 @@ class Trainer:
     ) -> DataLoader:
         if num_gpus > 1:
             if hasattr(model.module, "get_data_loader"):
-                loader = model.module.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank)
+                loader = model.module.get_data_loader(
+                    config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank
+                )
         else:
             if hasattr(model, "get_data_loader"):
                 loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
@@ -797,6 +802,7 @@ class Trainer:
                 loader_time = time.time() - loader_start_time
                 self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
                 outputs, _ = self.eval_step(batch, cur_step)
+                loader_start_time = time.time()
         # plot epoch stats, artifacts and figures
         if self.args.rank == 0:
             figures, audios = None, None
@@ -839,7 +845,7 @@ class Trainer:
         self.total_steps_done = self.restore_step

         for epoch in range(0, self.config.epochs):
-            if self.num_gpus:
+            if self.num_gpus > 1:
                 # let all processes sync up before starting with a new epoch of training
                 dist.barrier()
             self.callbacks.on_epoch_start()
@@ -868,6 +874,9 @@ class Trainer:
             self.callbacks.on_keyboard_interrupt()
             # if the output folder is empty remove the run.
             remove_experiment_folder(self.output_path)
+            # clear the DDP processes
+            if self.num_gpus > 1:
+                dist.destroy_process_group()
             # finish the wandb run and sync data
             self.dashboard_logger.finish()
             # stop without error signal
diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py
index 3b6e3f90..92ab0697 100644
--- a/TTS/tts/models/base_tts.py
+++ b/TTS/tts/models/base_tts.py
@@ -2,6 +2,7 @@ import os
 from typing import Dict, List, Tuple

 import torch
+import torch.distributed as dist
 from coqpit import Coqpit
 from torch import nn
 from torch.utils.data import DataLoader
@@ -164,7 +165,14 @@ class BaseTTS(BaseModel):
         }

     def get_data_loader(
-        self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int, rank: int=None
+        self,
+        config: Coqpit,
+        ap: AudioProcessor,
+        is_eval: bool,
+        data_items: List,
+        verbose: bool,
+        num_gpus: int,
+        rank: int = None,
     ) -> "DataLoader":
         if is_eval and not config.run_eval:
             loader = None
@@ -228,6 +236,10 @@ class BaseTTS(BaseModel):
             else:
                 self.train_data_items = dataset.items

+            # halt DDP processes for the main process to finish computing the phoneme cache
+            if num_gpus > 1:
+                dist.barrier()
+
             dataset.sort_items()

             sampler = DistributedSampler(dataset) if num_gpus > 1 else None
diff --git a/TTS/utils/trainer_utils.py b/TTS/utils/trainer_utils.py
index 90cc7d81..005114d1 100644
--- a/TTS/utils/trainer_utils.py
+++ b/TTS/utils/trainer_utils.py
@@ -1,5 +1,5 @@
 import importlib
-from typing import Dict, List
+from typing import Dict, List, Tuple

 import torch

@@ -9,7 +9,8 @@ from TTS.utils.training import NoamLR
 def is_apex_available():
     return importlib.util.find_spec("apex") is not None

-def setup_torch_training_env(cudnn_enable:bool, cudnn_benchmark:bool, use_ddp:bool=False) -> Tuple[bool, int]:
+
+def setup_torch_training_env(cudnn_enable: bool, cudnn_benchmark: bool, use_ddp: bool = False) -> Tuple[bool, int]:
     """Setup PyTorch environment for training.

     Args: