Fix distribute.py and ddp training

Eren Gölge 2021-08-12 22:22:32 +00:00
parent b02c4fe347
commit ecf5f17dca
4 changed files with 40 additions and 12 deletions

View File

@@ -32,6 +32,7 @@ def main():
command.append("--restore_path={}".format(args.restore_path))
command.append("--config_path={}".format(args.config_path))
command.append("--group_id=group_{}".format(group_id))
command.append("--use_ddp=true")
command += unargs
command.append("")

View File

@@ -16,6 +16,7 @@ from urllib.parse import urlparse
import fsspec
import torch
import torch.distributed as dist
from coqpit import Coqpit
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP_th
@@ -83,6 +84,7 @@ class TrainingArgs(Coqpit):
config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
use_ddp: bool = field(default=False, metadata={"help": "Use DDP in distributed training. It is set by `distribute.py`. Do not set it manually."})
class Trainer:
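
The new `use_ddp` field only records whether the process was launched by `distribute.py`. How such a flag is typically consumed is sketched below: initialise the process group, then wrap the model. This is a generic PyTorch pattern under assumed environment-variable conventions, not the Trainer's actual implementation.

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def maybe_wrap_ddp(model: torch.nn.Module, use_ddp: bool, rank: int):
    # If DDP is not requested, train the plain model on a single process.
    if not use_ddp:
        return model
    dist.init_process_group(
        backend="nccl",
        init_method="env://",  # expects MASTER_ADDR / MASTER_PORT in the environment
        world_size=int(os.environ.get("WORLD_SIZE", "1")),
        rank=rank,
    )
    torch.cuda.set_device(rank)
    model = model.cuda(rank)
    # Gradients are synchronised across processes by the DDP wrapper.
    return DDP(model, device_ids=[rank], output_device=rank)
```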
@@ -152,7 +154,7 @@ class Trainer:
"""
# set and initialize PyTorch runtime
self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp)
if config is None:
# parse config from console arguments
config, output_path, _, c_logger, dashboard_logger = process_args(args)
@@ -348,8 +350,8 @@ class Trainer:
restore_step = checkpoint["step"]
return model, optimizer, scaler, restore_step
@staticmethod
def _get_loader(
self,
model: nn.Module,
config: Coqpit,
ap: AudioProcessor,
@@ -358,8 +360,12 @@ class Trainer:
verbose: bool,
num_gpus: int,
) -> DataLoader:
if hasattr(model, "get_data_loader"):
loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
if num_gpus > 1:
if hasattr(model.module, "get_data_loader"):
loader = model.module.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank)
else:
if hasattr(model, "get_data_loader"):
loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
return loader
def get_train_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader:
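
The branching above is needed because `DistributedDataParallel` does not expose custom methods of the wrapped model; methods such as `get_data_loader` live on the inner module and must be reached through `.module`. A minimal, self-contained illustration of that dispatch (names are placeholders):

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def call_get_data_loader(model: nn.Module, *args, num_gpus: int = 1):
    # With more than one GPU the model is assumed to be DDP-wrapped,
    # so custom methods are looked up on the underlying module.
    target = model.module if num_gpus > 1 and isinstance(model, DDP) else model
    if hasattr(target, "get_data_loader"):
        return target.get_data_loader(*args)
    return None
```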
@@ -387,7 +393,10 @@ class Trainer:
Returns:
Dict: Formatted batch.
"""
batch = self.model.format_batch(batch)
if self.num_gpus > 1:
batch = self.model.module.format_batch(batch)
else:
batch = self.model.format_batch(batch)
if self.use_cuda:
for k, v in batch.items():
batch[k] = to_cuda(v)
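
Right after `format_batch`, every value in the batch dict is moved to the GPU. A small stand-alone sketch of that step, assuming `to_cuda` roughly amounts to a `.cuda()` call on tensors (non-tensor values are left untouched here):

```python
import torch


def batch_to_cuda(batch: dict) -> dict:
    # Move each tensor in the batch to the current CUDA device.
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.cuda(non_blocking=True)
    return batch
```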
@@ -674,16 +683,20 @@ class Trainer:
self.data_train,
verbose=True,
)
self.model.train()
if self.num_gpus > 1:
self.model.module.train()
else:
self.model.train()
epoch_start_time = time.time()
if self.use_cuda:
batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus))
else:
batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
self.c_logger.print_train_start()
loader_start_time = time.time()
for cur_step, batch in enumerate(self.train_loader):
loader_start_time = time.time()
_, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
loader_start_time = time.time()
epoch_time = time.time() - epoch_start_time
# Plot self.epochs_done Stats
if self.args.rank == 0:
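
The `batch_num_steps` computation above divides by `num_gpus` because, under DDP, a `DistributedSampler` gives each process only its shard of the dataset. Illustrative arithmetic with made-up dataset and batch sizes:

```python
dataset_len = 13_100   # illustrative dataset size
batch_size = 32
num_gpus = 2

steps_single_gpu = dataset_len // batch_size                 # 409 steps per epoch
steps_per_process = dataset_len // (batch_size * num_gpus)   # 204 steps per DDP process
print(steps_single_gpu, steps_per_process)
```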
@@ -826,6 +839,9 @@ class Trainer:
self.total_steps_done = self.restore_step
for epoch in range(0, self.config.epochs):
if self.num_gpus > 1:
# let all processes sync up before starting with a new epoch of training
dist.barrier()
self.callbacks.on_epoch_start()
self.keep_avg_train = KeepAverage()
self.keep_avg_eval = KeepAverage() if self.config.run_eval else None
@@ -839,7 +855,8 @@ class Trainer:
self.c_logger.print_epoch_end(
epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
)
self.save_best_model()
if self.args.rank in [None, 0]:
self.save_best_model()
self.callbacks.on_epoch_end()
def fit(self) -> None:
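
Two DDP conventions are added above: processes synchronise at every epoch boundary with `dist.barrier()`, and only rank 0 writes checkpoints. A compact sketch of the same pattern; the `save_checkpoint` helper below is a placeholder, not the project's function.

```python
import torch
import torch.distributed as dist


def save_checkpoint(model, epoch, path="checkpoint.pth"):
    # Placeholder saver used only for this sketch.
    torch.save({"model": model.state_dict(), "epoch": epoch}, path)


def run_epochs(model, num_epochs: int, rank: int, num_gpus: int):
    for epoch in range(num_epochs):
        if num_gpus > 1:
            # Let all processes sync up before starting a new epoch.
            dist.barrier()
        # ... one epoch of training and evaluation ...
        if rank in (None, 0):
            # Only the main process touches the filesystem, so workers
            # never overwrite each other's checkpoints.
            save_checkpoint(model, epoch)
```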

View File

@@ -164,7 +164,7 @@ class BaseTTS(BaseModel):
}
def get_data_loader(
self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int
self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int, rank: int = None
) -> "DataLoader":
if is_eval and not config.run_eval:
loader = None
@@ -212,7 +212,7 @@ class BaseTTS(BaseModel):
else None,
)
if config.use_phonemes and config.compute_input_seq_cache:
if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
if hasattr(self, "eval_data_items") and is_eval:
dataset.items = self.eval_data_items
elif hasattr(self, "train_data_items") and not is_eval:
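
The new `rank` argument typically serves two purposes in a `get_data_loader` implementation: it feeds a `DistributedSampler` so each process reads its own shard, and it gates expensive cache-producing preprocessing (here, the phoneme cache) to the main process only. The sketch below uses a stand-in dataset rather than the project's `TTSDataset`:

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


def build_loader(num_gpus: int, rank: int = None, batch_size: int = 8) -> DataLoader:
    dataset = TensorDataset(torch.arange(100).float())  # stand-in dataset

    if rank in (None, 0):
        pass  # e.g. compute the phoneme cache once; other ranks reuse it

    sampler = (
        DistributedSampler(dataset, num_replicas=num_gpus, rank=rank)
        if num_gpus > 1
        else None
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=sampler is None,  # DistributedSampler shuffles its own shard
        sampler=sampler,
    )
```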

View File

@@ -9,10 +9,20 @@ from TTS.utils.training import NoamLR
def is_apex_available():
return importlib.util.find_spec("apex") is not None
def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
def setup_torch_training_env(cudnn_enable: bool, cudnn_benchmark: bool, use_ddp: bool = False) -> Tuple[bool, int]:
"""Setup the PyTorch environment for training.
Args:
cudnn_enable (bool): Enable/disable CUDNN.
cudnn_benchmark (bool): Enable/disable CUDNN benchmarking. Better set to False if the input sequence length
varies between batches.
use_ddp (bool): DDP flag. True if DDP is enabled, False otherwise.
Returns:
Tuple[bool, int]: Whether CUDA is available and the number of GPUs in the environment.
"""
num_gpus = torch.cuda.device_count()
if num_gpus > 1:
if num_gpus > 1 and not use_ddp:
raise RuntimeError(
f" [!] {num_gpus} active GPUs. Define the target GPU by `CUDA_VISIBLE_DEVICES`. For multi-gpu training use `TTS/bin/distribute.py`."
)
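
Possible call sites for the updated helper, for reference; the import path below is an assumption and may not match the actual repository layout.

```python
from TTS.utils.trainer_utils import setup_torch_training_env  # assumed module path

# Single-process training: at most one GPU should be visible
# (e.g. CUDA_VISIBLE_DEVICES=0), otherwise the RuntimeError above is raised.
use_cuda, num_gpus = setup_torch_training_env(True, False)

# DDP worker launched by distribute.py: several GPUs may be visible,
# so the guard is relaxed by passing use_ddp=True.
use_cuda, num_gpus = setup_torch_training_env(True, False, use_ddp=True)
```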