Fix distribute.py and DDP training

Eren Gölge 2021-08-12 22:22:32 +00:00
parent b02c4fe347
commit ecf5f17dca
4 changed files with 40 additions and 12 deletions

TTS/bin/distribute.py

@@ -32,6 +32,7 @@ def main():
     command.append("--restore_path={}".format(args.restore_path))
     command.append("--config_path={}".format(args.config_path))
     command.append("--group_id=group_{}".format(group_id))
+    command.append("--use_ddp=true")
     command += unargs
     command.append("")

TTS/trainer.py

@@ -16,6 +16,7 @@ from urllib.parse import urlparse

 import fsspec
 import torch
+import torch.distributed as dist
 from coqpit import Coqpit
 from torch import nn
 from torch.nn.parallel import DistributedDataParallel as DDP_th
@@ -83,6 +84,7 @@ class TrainingArgs(Coqpit):
     config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
     rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
     group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
+    use_ddp: bool = field(default=False, metadata={"help": "Use DDP in distributed training. It is set by `distribute.py`. Do not set manually."})


 class Trainer:
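Since `--use_ddp=true` arrives as a string on the worker's command line, the boolean field has to be parsed from text. A plain-argparse sketch of that plumbing (the real code uses Coqpit's CLI parsing; this parser is only illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
# "true"/"false" strings from the launcher become a real bool
parser.add_argument("--use_ddp", type=lambda v: v.lower() == "true", default=False)
args = parser.parse_args(["--use_ddp=true"])
assert args.use_ddp is True
```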
@@ -152,7 +154,7 @@ class Trainer:
         """
         # set and initialize PyTorch runtime
-        self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
+        self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp)
         if config is None:
             # parse config from console arguments
             config, output_path, _, c_logger, dashboard_logger = process_args(args)
@@ -348,8 +350,8 @@ class Trainer:
         restore_step = checkpoint["step"]
         return model, optimizer, scaler, restore_step

-    @staticmethod
     def _get_loader(
+        self,
         model: nn.Module,
         config: Coqpit,
         ap: AudioProcessor,
@@ -358,6 +360,10 @@ class Trainer:
         verbose: bool,
         num_gpus: int,
     ) -> DataLoader:
-        if hasattr(model, "get_data_loader"):
-            loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
+        if num_gpus > 1:
+            if hasattr(model.module, "get_data_loader"):
+                loader = model.module.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank)
+        else:
+            if hasattr(model, "get_data_loader"):
+                loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
         return loader
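The branching above exists because `DistributedDataParallel` proxies only `forward`; custom methods such as `get_data_loader` or `format_batch` live on the wrapped `model.module`. A hypothetical helper that captures the same idiom (`unwrap_model` is not part of this codebase):

```python
from torch import nn


def unwrap_model(model: nn.Module) -> nn.Module:
    # DDP-wrapped models expose the original module as `model.module`;
    # plain models are returned unchanged.
    return model.module if hasattr(model, "module") else model
```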
@@ -387,6 +393,9 @@ class Trainer:

         Returns:
             Dict: Formatted batch.
         """
-        batch = self.model.format_batch(batch)
+        if self.num_gpus > 1:
+            batch = self.model.module.format_batch(batch)
+        else:
+            batch = self.model.format_batch(batch)
         if self.use_cuda:
             for k, v in batch.items():
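The loop body under `if self.use_cuda:` is truncated in this hunk; the usual idiom at this point is to move every tensor in the batch to the GPU, roughly as sketched below (assumed, not taken from the diff):

```python
import torch


def batch_to_cuda(batch: dict) -> dict:
    # Move tensors to the GPU without blocking; leave non-tensor values as-is.
    return {
        k: v.cuda(non_blocking=True) if isinstance(v, torch.Tensor) else v
        for k, v in batch.items()
    }
```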
@@ -674,6 +683,9 @@ class Trainer:
             self.data_train,
             verbose=True,
         )
-        self.model.train()
+        if self.num_gpus > 1:
+            self.model.module.train()
+        else:
+            self.model.train()
         epoch_start_time = time.time()
         if self.use_cuda:
@@ -681,9 +693,10 @@ class Trainer:
         else:
             batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
         self.c_logger.print_train_start()
-        for cur_step, batch in enumerate(self.train_loader):
-            loader_start_time = time.time()
+        loader_start_time = time.time()
+        for cur_step, batch in enumerate(self.train_loader):
             _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
+            loader_start_time = time.time()
         epoch_time = time.time() - epoch_start_time
         # Plot self.epochs_done Stats
         if self.args.rank == 0:
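The reorder fixes the loader timing: the clock now starts before the loop and restarts after each step, so the delta read at the top of the next iteration is purely the time spent waiting on the `DataLoader`. A runnable sketch of the pattern with stand-ins for the loader and the training step:

```python
import time

loader_start_time = time.time()
for step in range(3):  # stand-in for `enumerate(self.train_loader)`
    loader_time = time.time() - loader_start_time  # batch-fetch wait only
    time.sleep(0.01)  # stand-in for `self.train_step(...)`
    loader_start_time = time.time()  # restart the clock after the step
    print(f"step {step}: waited {loader_time:.4f}s for data")
```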
@@ -826,6 +839,9 @@ class Trainer:
         self.total_steps_done = self.restore_step

         for epoch in range(0, self.config.epochs):
+            if self.num_gpus > 1:
+                # let all processes sync up before starting a new epoch of training
+                dist.barrier()
             self.callbacks.on_epoch_start()
             self.keep_avg_train = KeepAverage()
             self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
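`dist.barrier()` blocks each process until every rank in the group reaches the same call, so no worker runs ahead into the next epoch. A minimal sketch of the idiom, guarded so it is a no-op outside DDP (`epoch_loop` is a hypothetical helper):

```python
import torch.distributed as dist


def epoch_loop(num_epochs, run_one_epoch):
    for epoch in range(num_epochs):
        if dist.is_available() and dist.is_initialized():
            dist.barrier()  # all ranks sync here before the epoch starts
        run_one_epoch(epoch)
```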
@@ -839,6 +855,7 @@ class Trainer:
             self.c_logger.print_epoch_end(
                 epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
             )
-            self.save_best_model()
+            if self.args.rank in [None, 0]:
+                self.save_best_model()
             self.callbacks.on_epoch_end()

TTS/tts/models/base_tts.py

@@ -164,7 +164,7 @@ class BaseTTS(BaseModel):
         }

     def get_data_loader(
-        self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int
+        self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int, rank: int = None
     ) -> "DataLoader":
         if is_eval and not config.run_eval:
             loader = None
@@ -212,7 +212,7 @@ class BaseTTS(BaseModel):
                 else None,
             )

-            if config.use_phonemes and config.compute_input_seq_cache:
+            if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
                 if hasattr(self, "eval_data_items") and is_eval:
                     dataset.items = self.eval_data_items
                 elif hasattr(self, "train_data_items") and not is_eval:
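Restricting the cache computation to rank 0 keeps N worker processes from racing to write the same phoneme-cache files. The usual companion is a barrier so the other ranks wait for the finished cache; a sketch of that idiom (`build_input_seq_cache` and `compute_cache` are illustrative names, and the barrier placement is not from this diff):

```python
import torch.distributed as dist


def build_input_seq_cache(rank, compute_cache):
    if rank in (None, 0):
        compute_cache()  # single writer: only rank 0 (or a non-DDP run) writes
    if dist.is_available() and dist.is_initialized():
        dist.barrier()  # remaining ranks wait until the cache is complete
```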

TTS/utils/trainer_utils.py

@@ -9,10 +9,20 @@ from TTS.utils.training import NoamLR

 def is_apex_available():
     return importlib.util.find_spec("apex") is not None


-def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
+def setup_torch_training_env(cudnn_enable: bool, cudnn_benchmark: bool, use_ddp: bool = False) -> Tuple[bool, int]:
+    """Set up the PyTorch environment for training.
+
+    Args:
+        cudnn_enable (bool): Enable/disable cuDNN.
+        cudnn_benchmark (bool): Enable/disable cuDNN benchmarking. Better set to False when the input sequence length
+            varies between batches.
+        use_ddp (bool): DDP flag. True if DDP is enabled, False otherwise.
+
+    Returns:
+        Tuple[bool, int]: Whether CUDA is available and the number of GPUs in the environment.
+    """
     num_gpus = torch.cuda.device_count()
-    if num_gpus > 1:
+    if num_gpus > 1 and not use_ddp:
         raise RuntimeError(
             f" [!] {num_gpus} active GPUs. Define the target GPU by `CUDA_VISIBLE_DEVICES`. For multi-gpu training use `TTS/bin/distribute.py`."
         )
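A usage sketch of the relaxed guard (assuming the function is imported from its module): with `use_ddp=True` a worker may see several GPUs, while a plain single-process run on a multi-GPU machine still raises the `RuntimeError` above.

```python
# DDP worker: several visible GPUs are now allowed.
use_cuda, num_gpus = setup_torch_training_env(
    cudnn_enable=True, cudnn_benchmark=False, use_ddp=True
)
```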