From c5241d71ab0628261acb070cf339b8cd1a52f32e Mon Sep 17 00:00:00 2001
From: Enno Hermann
Date: Wed, 26 Jun 2024 00:24:04 +0200
Subject: [PATCH] chore: address pytorch deprecations

torch.range(a, b) == torch.arange(a, b+1)

meshgrid indexing:
https://github.com/pytorch/pytorch/issues/50276

checkpoint use_reentrant:
https://dev-discuss.pytorch.org/t/bc-breaking-update-to-torch-utils-checkpoint-not-passing-in-use-reentrant-flag-will-raise-an-error/1745

optimizer.step() before scheduler.step():
https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
---
 TTS/bin/train_encoder.py              | 7 ++++---
 TTS/encoder/dataset.py                | 2 +-
 TTS/tts/layers/losses.py              | 2 +-
 TTS/tts/layers/overflow/neural_hmm.py | 3 ++-
 tests/tts_tests/test_helpers.py       | 2 +-
 5 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py
index c0292743..49b450cf 100644
--- a/TTS/bin/train_encoder.py
+++ b/TTS/bin/train_encoder.py
@@ -161,9 +161,6 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
             loader_time = time.time() - end_time
             global_step += 1
 
-            # setup lr
-            if c.lr_decay:
-                scheduler.step()
             optimizer.zero_grad()
 
             # dispatch data to GPU
@@ -182,6 +179,10 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
             grad_norm, _ = check_update(model, c.grad_clip)
             optimizer.step()
 
+            # setup lr
+            if c.lr_decay:
+                scheduler.step()
+
             step_time = time.time() - start_time
             epoch_time += step_time
 
diff --git a/TTS/encoder/dataset.py b/TTS/encoder/dataset.py
index 81385c6c..bb780e3c 100644
--- a/TTS/encoder/dataset.py
+++ b/TTS/encoder/dataset.py
@@ -55,7 +55,7 @@ class EncoderDataset(Dataset):
         logger.info(" | Number of instances: %d", len(self.items))
         logger.info(" | Sequence length: %d", self.seq_len)
         logger.info(" | Number of classes: %d", len(self.classes))
-        logger.info(" | Classes: %d", self.classes)
+        logger.info(" | Classes: %s", self.classes)
 
     def load_wav(self, filename):
         audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py
index cd6cd0ae..5ebed81d 100644
--- a/TTS/tts/layers/losses.py
+++ b/TTS/tts/layers/losses.py
@@ -255,7 +255,7 @@ class GuidedAttentionLoss(torch.nn.Module):
 
     @staticmethod
     def _make_ga_mask(ilen, olen, sigma):
-        grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
+        grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen), indexing="ij")
         grid_x, grid_y = grid_x.float(), grid_y.float()
         return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2)))
 
diff --git a/TTS/tts/layers/overflow/neural_hmm.py b/TTS/tts/layers/overflow/neural_hmm.py
index 0631ba98..a12becef 100644
--- a/TTS/tts/layers/overflow/neural_hmm.py
+++ b/TTS/tts/layers/overflow/neural_hmm.py
@@ -128,7 +128,8 @@ class NeuralHMM(nn.Module):
             # Get mean, std and transition vector from decoder for this timestep
             # Note: Gradient checkpointing currently doesn't works with multiple gpus inside a loop
             if self.use_grad_checkpointing and self.training:
-                mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs)
+                # TODO: use_reentrant=False is recommended
+                mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs, use_reentrant=True)
             else:
                 mean, std, transition_vector = self.output_net(h_memory, inputs)
 
diff --git a/tests/tts_tests/test_helpers.py b/tests/tts_tests/test_helpers.py
index dbd7f54e..d07efa36 100644
--- a/tests/tts_tests/test_helpers.py
+++ b/tests/tts_tests/test_helpers.py
@@ -31,7 +31,7 @@ def test_sequence_mask():
 
 
 def test_segment():
-    x = T.range(0, 11)
+    x = T.arange(0, 12)
     x = x.repeat(8, 1).unsqueeze(1)
     segment_ids = T.randint(0, 7, (8,))
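
A standalone sketch (not part of the patch itself) illustrating the four
deprecation fixes named in the commit message. The toy Linear layer, tensor
shapes, and optimizer settings below are hypothetical, chosen only for
demonstration:

    import torch
    from torch.utils.checkpoint import checkpoint

    # 1. torch.range(0, 11) is deprecated and end-inclusive; torch.arange is
    #    end-exclusive, so the equivalent call is torch.arange(0, 12).
    x = torch.arange(0, 12)

    # 2. torch.meshgrid warns unless `indexing` is passed explicitly;
    #    indexing="ij" preserves the old (matrix-style) behaviour.
    grid_x, grid_y = torch.meshgrid(torch.arange(4), torch.arange(3), indexing="ij")

    # 3. Newer releases require an explicit use_reentrant flag for
    #    checkpoint(); the patch keeps use_reentrant=True (the historical
    #    default), though use_reentrant=False is the recommended setting.
    layer = torch.nn.Linear(8, 8)
    inp = torch.randn(2, 8, requires_grad=True)
    out = checkpoint(layer, inp, use_reentrant=True)

    # 4. Since PyTorch 1.1, scheduler.step() must run after optimizer.step(),
    #    otherwise the first value of the learning-rate schedule is skipped.
    optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    out.sum().backward()
    optimizer.step()
    scheduler.step()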