mirror of https://github.com/coqui-ai/TTS.git
Fix distribute.py and ddp training
parent b02c4fe347
commit ecf5f17dca
TTS/bin/distribute.py

@@ -32,6 +32,7 @@ def main():
     command.append("--restore_path={}".format(args.restore_path))
     command.append("--config_path={}".format(args.config_path))
     command.append("--group_id=group_{}".format(group_id))
+    command.append("--use_ddp=true")
     command += unargs
     command.append("")

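For context, the trailing `command.append("")` is a placeholder that the launcher later overwrites with a per-process `--rank` argument. A minimal sketch of that spawning pattern, with the helper name and details assumed rather than taken verbatim from this commit:

```python
import os
import subprocess

import torch


def launch_ddp_processes(command):
    """Spawn one training process per visible GPU (illustrative sketch)."""
    num_gpus = torch.cuda.device_count()
    processes = []
    for rank in range(num_gpus):
        my_env = os.environ.copy()
        command[-1] = "--rank={}".format(rank)  # fill the "" placeholder
        # silence all but the first process to keep the console readable
        stdout = None if rank == 0 else open(os.devnull, "w")
        p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env)
        processes.append(p)
    for p in processes:
        p.wait()
```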
TTS/trainer.py

@@ -16,6 +16,7 @@ from urllib.parse import urlparse

 import fsspec
 import torch
+import torch.distributed as dist
 from coqpit import Coqpit
 from torch import nn
 from torch.nn.parallel import DistributedDataParallel as DDP_th

@@ -83,6 +84,7 @@ class TrainingArgs(Coqpit):
     config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
     rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
     group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
+    use_ddp: bool = field(default=False, metadata={"help": "Use DDP in distributed training. It is set by `distribute.py`. Do not set manually."})


 class Trainer:

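The `rank`, `group_id`, and `use_ddp` fields carry everything a worker needs to join the process group. A hedged sketch of the initialization these fields typically feed (`init_distributed` here is illustrative, not the `TTS.utils.distribute` implementation):

```python
import torch
import torch.distributed as dist


def init_distributed(rank: int, num_gpus: int, group_id: str,
                     dist_url: str = "tcp://localhost:54321"):
    """Join the NCCL process group for this worker (illustrative sketch)."""
    assert torch.cuda.is_available(), "Distributed training requires CUDA."
    dist.init_process_group(
        backend="nccl",
        init_method=dist_url,
        world_size=num_gpus,
        rank=rank,
        group_name=group_id,
    )
    torch.cuda.set_device(rank % torch.cuda.device_count())
```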
@@ -152,7 +154,7 @@ class Trainer:
         """

         # set and initialize Pytorch runtime
-        self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark)
+        self.use_cuda, self.num_gpus = setup_torch_training_env(True, cudnn_benchmark, args.use_ddp)
         if config is None:
             # parse config from console arguments
             config, output_path, _, c_logger, dashboard_logger = process_args(args)

@@ -348,8 +350,8 @@ class Trainer:
         restore_step = checkpoint["step"]
         return model, optimizer, scaler, restore_step

-    @staticmethod
     def _get_loader(
+        self,
         model: nn.Module,
         config: Coqpit,
         ap: AudioProcessor,

@@ -358,8 +360,12 @@ class Trainer:
         verbose: bool,
         num_gpus: int,
     ) -> DataLoader:
-        if hasattr(model, "get_data_loader"):
-            loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
+        if num_gpus > 1:
+            if hasattr(model.module, "get_data_loader"):
+                loader = model.module.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank)
+        else:
+            if hasattr(model, "get_data_loader"):
+                loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
         return loader

     def get_train_dataloader(self, ap: AudioProcessor, data_items: List, verbose: bool) -> DataLoader:

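The branching above exists because `DistributedDataParallel` wraps the model, so model-specific methods such as `get_data_loader` live one level down on `model.module`. A small helper capturing the same unwrap pattern (an illustration, not part of the commit):

```python
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP_th


def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the underlying model whether or not it is wrapped by DDP."""
    return model.module if isinstance(model, DDP_th) else model
```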
@@ -387,7 +393,10 @@ class Trainer:
         Returns:
             Dict: Formatted batch.
         """
-        batch = self.model.format_batch(batch)
+        if self.num_gpus > 1:
+            batch = self.model.module.format_batch(batch)
+        else:
+            batch = self.model.format_batch(batch)
         if self.use_cuda:
             for k, v in batch.items():
                 batch[k] = to_cuda(v)

@@ -674,16 +683,20 @@ class Trainer:
             self.data_train,
             verbose=True,
         )
-        self.model.train()
+        if self.num_gpus > 1:
+            self.model.module.train()
+        else:
+            self.model.train()
         epoch_start_time = time.time()
         if self.use_cuda:
             batch_num_steps = int(len(self.train_loader.dataset) / (self.config.batch_size * self.num_gpus))
         else:
             batch_num_steps = int(len(self.train_loader.dataset) / self.config.batch_size)
         self.c_logger.print_train_start()
+        loader_start_time = time.time()
         for cur_step, batch in enumerate(self.train_loader):
-            loader_start_time = time.time()
             _, _ = self.train_step(batch, batch_num_steps, cur_step, loader_start_time)
+            loader_start_time = time.time()
         epoch_time = time.time() - epoch_start_time
         # Plot self.epochs_done Stats
         if self.args.rank == 0:

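Moving `loader_start_time` ahead of the loop and resetting it after each step makes the measured interval cover the dataloader fetch rather than the training step. The generic pattern, sketched with assumed names:

```python
import time


def run_epoch(loader, step_fn):
    """Illustrative loop: measure time blocked on the dataloader."""
    loader_start_time = time.time()
    for batch in loader:
        # the fetch for `batch` happened between the last reset and now, so
        # `step_fn` can log time.time() - loader_start_time as loader time
        step_fn(batch, loader_start_time)
        loader_start_time = time.time()  # reset before the next fetch begins
```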
@@ -826,6 +839,9 @@ class Trainer:
         self.total_steps_done = self.restore_step

         for epoch in range(0, self.config.epochs):
+            if self.num_gpus > 1:
+                # let all processes sync up before starting a new training epoch
+                dist.barrier()
             self.callbacks.on_epoch_start()
             self.keep_avg_train = KeepAverage()
             self.keep_avg_eval = KeepAverage() if self.config.run_eval else None

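`dist.barrier()` blocks each rank until every other rank arrives, so no worker races ahead into the next epoch. Note it requires an initialized process group; a defensive form of the same guard might look like:

```python
import torch.distributed as dist

# Only synchronize when a process group actually exists (sketch).
if dist.is_available() and dist.is_initialized():
    dist.barrier()  # all ranks resume together
```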
@@ -839,7 +855,8 @@ class Trainer:
             self.c_logger.print_epoch_end(
                 epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
             )
-            self.save_best_model()
+            if self.args.rank in [None, 0]:
+                self.save_best_model()
             self.callbacks.on_epoch_end()

     def fit(self) -> None:

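Checkpoint writes are restricted to the main process so concurrent ranks do not clobber the same files; the `None` case covers call sites where no rank is set. A tiny helper expressing the same idea (hypothetical, not in the commit):

```python
def is_main_process(rank) -> bool:
    """True for rank 0 under DDP and for runs where no rank was set."""
    return rank in (None, 0)
```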
TTS/tts/models/base_tts.py

@@ -164,7 +164,7 @@ class BaseTTS(BaseModel):
         }

     def get_data_loader(
-        self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int
+        self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int, rank: int = None
     ) -> "DataLoader":
         if is_eval and not config.run_eval:
             loader = None

@@ -212,7 +212,7 @@ class BaseTTS(BaseModel):
                 else None,
             )

-            if config.use_phonemes and config.compute_input_seq_cache:
+            if config.use_phonemes and config.compute_input_seq_cache and rank in [None, 0]:
                 if hasattr(self, "eval_data_items") and is_eval:
                     dataset.items = self.eval_data_items
                 elif hasattr(self, "train_data_items") and not is_eval:

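Restricting the phoneme-cache precompute to `rank in [None, 0]` stops every worker from writing the same cache files at once. The new `rank` argument also lets the loader shard data across workers; the standard mechanism is `DistributedSampler`, sketched here under assumed names rather than copied from `BaseTTS`:

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def make_loader(dataset, batch_size: int, num_gpus: int):
    """Build a loader that shards the dataset across DDP ranks (sketch)."""
    sampler = DistributedSampler(dataset) if num_gpus > 1 else None
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=sampler is None,  # DistributedSampler already shuffles per rank
        sampler=sampler,
    )
```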
TTS/utils/trainer_utils.py

@@ -9,10 +9,20 @@ from TTS.utils.training import NoamLR
 def is_apex_available():
     return importlib.util.find_spec("apex") is not None


-def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
+def setup_torch_training_env(cudnn_enable: bool, cudnn_benchmark: bool, use_ddp: bool = False) -> Tuple[bool, int]:
+    """Setup PyTorch environment for training.
+
+    Args:
+        cudnn_enable (bool): Enable/disable CUDNN.
+        cudnn_benchmark (bool): Enable/disable CUDNN benchmarking. Better to set to False if input sequence length is
+            variable between batches.
+        use_ddp (bool): DDP flag. True if DDP is enabled, False otherwise.
+
+    Returns:
+        Tuple[bool, int]: is cuda on or off and number of GPUs in the environment.
+    """
     num_gpus = torch.cuda.device_count()
-    if num_gpus > 1:
+    if num_gpus > 1 and not use_ddp:
         raise RuntimeError(
             f" [!] {num_gpus} active GPUs. Define the target GPU by `CUDA_VISIBLE_DEVICES`. For multi-gpu training use `TTS/bin/distribute.py`."
         )

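With the extra flag, the multi-GPU guard only fires when a run is not managed by `distribute.py`. Expected behavior, assuming the import path below:

```python
from TTS.utils.trainer_utils import setup_torch_training_env  # path assumed

# Single visible GPU (or CPU): the guard passes.
use_cuda, num_gpus = setup_torch_training_env(True, False)

# Several visible GPUs without DDP: raises RuntimeError, pointing the user
# at CUDA_VISIBLE_DEVICES or TTS/bin/distribute.py.

# Several visible GPUs with use_ddp=True (set by distribute.py): the guard
# is skipped so each spawned rank can initialize normally.
use_cuda, num_gpus = setup_torch_training_env(True, False, use_ddp=True)
```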