chore: address pytorch deprecations

torch.range(a, b) == torch.arange(a, b+1)

meshgrid indexing: https://github.com/pytorch/pytorch/issues/50276

checkpoint use_reentrant:
https://dev-discuss.pytorch.org/t/bc-breaking-update-to-torch-utils-checkpoint-not-passing-in-use-reentrant-flag-will-raise-an-error/1745

optimizer.step() before scheduler.step():
https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
This commit is contained in:
Enno Hermann 2024-06-26 00:24:04 +02:00
parent a755328e49
commit c5241d71ab
5 changed files with 9 additions and 7 deletions

View File

@ -161,9 +161,6 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
loader_time = time.time() - end_time
global_step += 1
# setup lr
if c.lr_decay:
scheduler.step()
optimizer.zero_grad()
# dispatch data to GPU
@ -182,6 +179,10 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
grad_norm, _ = check_update(model, c.grad_clip)
optimizer.step()
# setup lr
if c.lr_decay:
scheduler.step()
step_time = time.time() - start_time
epoch_time += step_time

View File

@ -55,7 +55,7 @@ class EncoderDataset(Dataset):
logger.info(" | Number of instances: %d", len(self.items))
logger.info(" | Sequence length: %d", self.seq_len)
logger.info(" | Number of classes: %d", len(self.classes))
logger.info(" | Classes: %d", self.classes)
logger.info(" | Classes: %s", self.classes)
def load_wav(self, filename):
audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)

View File

@ -255,7 +255,7 @@ class GuidedAttentionLoss(torch.nn.Module):
@staticmethod
def _make_ga_mask(ilen, olen, sigma):
grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen), indexing="ij")
grid_x, grid_y = grid_x.float(), grid_y.float()
return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2)))

View File

@ -128,7 +128,8 @@ class NeuralHMM(nn.Module):
# Get mean, std and transition vector from decoder for this timestep
# Note: Gradient checkpointing currently doesn't works with multiple gpus inside a loop
if self.use_grad_checkpointing and self.training:
mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs)
# TODO: use_reentrant=False is recommended
mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs, use_reentrant=True)
else:
mean, std, transition_vector = self.output_net(h_memory, inputs)

View File

@ -31,7 +31,7 @@ def test_sequence_mask():
def test_segment():
x = T.range(0, 11)
x = T.arange(0, 12)
x = x.repeat(8, 1).unsqueeze(1)
segment_ids = T.randint(0, 7, (8,))