mirror of https://github.com/coqui-ai/TTS.git
adjust `distribute.py` for `train_tts.py`
This commit is contained in:
parent 614738cc85
commit 59be1b9af1
@@ -30,7 +30,7 @@ def main():
     parser.add_argument(
         "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in sys.argv
     )
-    args = parser.parse_args()
+    args, unargs = parser.parse_known_args()

     num_gpus = torch.cuda.device_count()
     group_id = time.strftime("%Y_%m_%d-%H%M%S")
@@ -42,6 +42,7 @@ def main():
     command.append("--restore_path={}".format(args.restore_path))
     command.append("--config_path={}".format(args.config_path))
     command.append("--group_id=group_{}".format(group_id))
+    command += unargs
     command.append("")

     # run processes
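The switch from parse_args() to parse_known_args() is what lets distribute.py accept trainer flags it does not declare itself and pass them straight through to every spawned process via `command += unargs`. A minimal sketch of that forwarding pattern, outside the TTS code base (the child command here just echoes its arguments):

# sketch: forward unrecognized CLI flags to a child process
import argparse
import subprocess
import sys

parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, required=True)

# parse_known_args() returns (known, leftover) instead of erroring on extras
args, unargs = parser.parse_known_args()

# build the child command and append the leftover flags untouched
command = [sys.executable, "-c", "import sys; print(sys.argv[1:])"]
command.append("--config_path={}".format(args.config_path))
command += unargs
subprocess.run(command, check=True)

Invoked as `python sketch.py --config_path=config.json --lr=1e-4`, the child prints ['--config_path=config.json', '--lr=1e-4']: the unknown `--lr` flag survives the hop, which is what the per-GPU training processes rely on here.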
@@ -65,15 +65,19 @@ class TrainerTTS:
         self,
         args: Union[Coqpit, Namespace],
         config: Coqpit,
-        c_logger: ConsoleLogger,
-        tb_logger: TensorboardLogger,
+        c_logger: ConsoleLogger = None,
+        tb_logger: TensorboardLogger = None,
         model: nn.Module = None,
         output_path: str = None,
     ) -> None:
         self.args = args
         self.config = config
-        self.c_logger = c_logger
-        self.tb_logger = tb_logger
+        self.c_logger = ConsoleLogger() if c_logger is None else c_logger
+        if tb_logger is None:
+            self.tb_logger = TensorboardLogger(output_path, model_name=config.model)
+            self.tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
+        else:
+            self.tb_logger = tb_logger
         self.output_path = output_path

         self.total_steps_done = 0
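With the new defaults, the trainer can be built without pre-constructed loggers: when `c_logger`/`tb_logger` are left as None it creates a ConsoleLogger and a TensorboardLogger from `output_path` and `config.model` on its own. A minimal sketch of that default-or-injected pattern with stand-in classes (the Stub* names are placeholders for illustration, not the TTS API):

# sketch: build default collaborators only when the caller did not inject them
class StubConsoleLogger:
    def print_line(self, msg: str) -> None:
        print(msg)


class StubTbLogger:
    def __init__(self, output_path: str, model_name: str) -> None:
        self.output_path, self.model_name = output_path, model_name


class StubTrainer:
    def __init__(self, model_name: str, c_logger=None, tb_logger=None, output_path: str = None) -> None:
        # keep an injected logger when given, otherwise create the default one
        self.c_logger = StubConsoleLogger() if c_logger is None else c_logger
        self.tb_logger = StubTbLogger(output_path, model_name=model_name) if tb_logger is None else tb_logger


trainer = StubTrainer("Tacotron2", output_path="/tmp/run")  # both loggers are defaulted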
@@ -117,20 +121,20 @@ class TrainerTTS:
         # setup criterion
         self.criterion = self.get_criterion(self.config)

-        if self.use_cuda:
-            self.model.cuda()
-            self.criterion.cuda()
-
         # DISTRUBUTED
         if self.num_gpus > 1:
             init_distributed(
                 args.rank,
                 self.num_gpus,
                 args.group_id,
-                self.config.distributed["backend"],
-                self.config.distributed["url"],
+                self.config.distributed_backend,
+                self.config.distributed_url,
             )

+        if self.use_cuda:
+            self.model.cuda()
+            self.criterion.cuda()
+
         # scalers for mixed precision training
         self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None
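The nested `config.distributed[...]` lookups are replaced by the flat `distributed_backend` / `distributed_url` fields that the JSON config further down also gains. For reference, a backend string plus an init URL is exactly what plain PyTorch process-group initialisation consumes; a minimal single-process sketch (using the gloo backend so it runs without a GPU, and not the repository's init_distributed helper):

# sketch: what a backend + url pair feeds into in plain PyTorch
import torch.distributed as dist

dist.init_process_group(
    backend="gloo",                       # cf. config.distributed_backend
    init_method="tcp://localhost:54321",  # cf. config.distributed_url
    rank=0,
    world_size=1,
)
print(dist.get_rank(), dist.get_world_size())  # 0 1
dist.destroy_process_group()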
@@ -147,7 +151,7 @@ class TrainerTTS:

         # DISTRUBUTED
         if self.num_gpus > 1:
-            self.model = DDP_th(self.model, device_ids=[args.rank])
+            self.model = DDP_th(self.model, device_ids=[args.rank], output_device=args.rank)

         # count model size
         num_params = count_parameters(self.model)
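Adding `output_device=args.rank` pins DistributedDataParallel's outputs to the same GPU that holds the replica named in `device_ids`. A hedged sketch of the wrapping step (assuming the process group is already initialised and a GPU is visible per rank; the helper name is illustrative, not part of the repository):

# sketch: wrap a module with DDP, keeping replica and outputs on the local rank's GPU
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP_th

def wrap_for_ddp(model: nn.Module, rank: int) -> nn.Module:
    model = model.cuda(rank)
    # device_ids: where the replica lives; output_device: where outputs are gathered
    return DDP_th(model, device_ids=[rank], output_device=rank)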
@@ -377,6 +381,11 @@ class TrainerTTS:
             "item_idx": item_idx,
         }

+    def _train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
+        if hasattr(self.model, "module"):
+            return self.model.module.train_step(batch, criterion)
+        return self.model.train_step(batch, criterion)
+
     def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
         self.on_train_step_start()
         step_start_time = time.time()
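`_train_step` (and `_eval_step` below) exist because DDP wraps the network and the original object then lives at `self.model.module`; the `hasattr(self.model, "module")` check dispatches to the right target in both the single-GPU and multi-GPU cases. The same idea as a free-standing helper (illustrative only, not part of the repository):

# sketch: reach the underlying model whether or not it is DDP-wrapped
from torch import nn

def unwrap_model(model: nn.Module) -> nn.Module:
    # DistributedDataParallel keeps the wrapped network on .module
    return model.module if hasattr(model, "module") else model

# usage: unwrap_model(self.model).train_step(batch, criterion)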
@@ -389,7 +398,7 @@ class TrainerTTS:
         self.optimizer.zero_grad()

         with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
-            outputs, loss_dict = self.model.train_step(batch, self.criterion)
+            outputs, loss_dict = self._train_step(batch, self.criterion)

         # check nan loss
         if torch.isnan(loss_dict["loss"]).any():
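The autocast block pairs with the GradScaler created earlier as `self.scaler`. For context, the canonical torch.cuda.amp step looks roughly like the sketch below (generic model/criterion/optimizer names, not the trainer's actual control flow):

# sketch: standard mixed-precision step with autocast + GradScaler
import torch

def amp_step(model, batch, criterion, optimizer, scaler, mixed_precision=True):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=mixed_precision):
        outputs = model(batch)         # forward pass in reduced precision where safe
        loss = criterion(outputs, batch)
    if scaler is not None:
        scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
        scaler.step(optimizer)         # unscale and apply the optimizer step
        scaler.update()
    else:
        loss.backward()
        optimizer.step()
    return loss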
@@ -473,7 +482,10 @@ class TrainerTTS:
                     scaler=self.scaler.state_dict() if self.config.mixed_precision else None,
                 )
             # training visualizations
-            figures, audios = self.model.train_log(self.ap, batch, outputs)
+            if hasattr(self.model, "module"):
+                figures, audios = self.model.module.train_log(self.ap, batch, outputs)
+            else:
+                figures, audios = self.model.train_log(self.ap, batch, outputs)
             self.tb_logger.tb_train_figures(self.total_steps_done, figures)
             self.tb_logger.tb_train_audios(self.total_steps_done, {"TrainAudio": audios}, self.ap.sample_rate)
         self.total_steps_done += 1
@@ -500,12 +512,17 @@ class TrainerTTS:
             if self.config.tb_model_param_stats:
                 self.tb_logger.tb_model_weights(self.model, self.total_steps_done)

+    def _eval_step(self, batch: Dict) -> Tuple[Dict, Dict]:
+        if hasattr(self.model, "module"):
+            return self.model.module.eval_step(batch, self.criterion)
+        return self.model.eval_step(batch, self.criterion)
+
     def eval_step(self, batch: Dict, step: int) -> Tuple[Dict, Dict]:
         with torch.no_grad():
             step_start_time = time.time()

             with torch.cuda.amp.autocast(enabled=self.config.mixed_precision):
-                outputs, loss_dict = self.model.eval_step(batch, self.criterion)
+                outputs, loss_dict = self._eval_step(batch)

             step_time = time.time() - step_start_time
@@ -542,7 +559,10 @@ class TrainerTTS:
             outputs, _ = self.eval_step(batch, cur_step)
         # Plot epoch stats and samples from the last batch.
         if self.args.rank == 0:
-            figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
+            if hasattr(self.model, "module"):
+                figures, eval_audios = self.model.module.eval_log(self.ap, batch, outputs)
+            else:
+                figures, eval_audios = self.model.eval_log(self.ap, batch, outputs)
             self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
             self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)
@@ -642,7 +662,7 @@ class TrainerTTS:
             self.train_epoch()
             if self.config.run_eval:
                 self.eval_epoch()
-            if epoch >= self.config.test_delay_epochs:
+            if epoch >= self.config.test_delay_epochs and self.args.rank < 0:
                 self.test_run()
             self.c_logger.print_epoch_end(
                 epoch, self.keep_avg_eval.avg_values if self.config.run_eval else self.keep_avg_train.avg_values
@@ -86,7 +86,11 @@ def run_model_torch(
        Dict: model outputs.
    """
    input_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)
-   outputs = model.inference(
+   if hasattr(model, "module"):
+       _func = model.module.inference
+   else:
+       _func = model.inference
+   outputs = _func(
        inputs,
        cond_input={
            "x_lengths": input_lengths,
@@ -36,12 +36,14 @@
         "gst_num_heads": 4,
         "gst_num_style_tokens": 10
     },
+    "distributed_backend": "gloo",
+    "distributed_url": "tcp:\/\/localhost:54321",
     "model": "Tacotron2",
     "run_name": "ljspeech-ddc",
     "run_description": "tacotron2 with double decoder consistency.",
     "batch_size": 64,
     "eval_batch_size": 16,
-    "mixed_precision": true,
+    "mixed_precision": false,
     "loss_masking": true,
     "decoder_loss_alpha": 0.25,
     "postnet_loss_alpha": 0.25,
@@ -54,6 +56,7 @@
     "run_eval": true,
     "test_delay_epochs": 10,
     "test_sentences_file": null,
+    "max_decoder_steps": 50,
     "noam_schedule": true,
     "grad_clip": 0.05,
     "epochs": 1000,
@@ -88,4 +91,4 @@
     "phoneme_cache_path": "DEFINE THIS",
     "use_phonemes": false,
     "phoneme_language": "en-us"
 }