mirror of https://github.com/coqui-ai/TTS.git
chore: address pytorch deprecations
torch.range(a, b) == torch.arange(a, b+1) meshgrid indexing: https://github.com/pytorch/pytorch/issues/50276 checkpoint use_reentrant: https://dev-discuss.pytorch.org/t/bc-breaking-update-to-torch-utils-checkpoint-not-passing-in-use-reentrant-flag-will-raise-an-error/1745 optimizer.step() before scheduler.step(): https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
This commit is contained in:
parent
a755328e49
commit
c5241d71ab
|
@ -161,9 +161,6 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
|
|||
loader_time = time.time() - end_time
|
||||
global_step += 1
|
||||
|
||||
# setup lr
|
||||
if c.lr_decay:
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# dispatch data to GPU
|
||||
|
@ -182,6 +179,10 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
|
|||
grad_norm, _ = check_update(model, c.grad_clip)
|
||||
optimizer.step()
|
||||
|
||||
# setup lr
|
||||
if c.lr_decay:
|
||||
scheduler.step()
|
||||
|
||||
step_time = time.time() - start_time
|
||||
epoch_time += step_time
|
||||
|
||||
|
|
|
@ -55,7 +55,7 @@ class EncoderDataset(Dataset):
|
|||
logger.info(" | Number of instances: %d", len(self.items))
|
||||
logger.info(" | Sequence length: %d", self.seq_len)
|
||||
logger.info(" | Number of classes: %d", len(self.classes))
|
||||
logger.info(" | Classes: %d", self.classes)
|
||||
logger.info(" | Classes: %s", self.classes)
|
||||
|
||||
def load_wav(self, filename):
|
||||
audio = self.ap.load_wav(filename, sr=self.ap.sample_rate)
|
||||
|
|
|
@ -255,7 +255,7 @@ class GuidedAttentionLoss(torch.nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def _make_ga_mask(ilen, olen, sigma):
|
||||
grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
|
||||
grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen), indexing="ij")
|
||||
grid_x, grid_y = grid_x.float(), grid_y.float()
|
||||
return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2)))
|
||||
|
||||
|
|
|
@ -128,7 +128,8 @@ class NeuralHMM(nn.Module):
|
|||
# Get mean, std and transition vector from decoder for this timestep
|
||||
# Note: Gradient checkpointing currently doesn't works with multiple gpus inside a loop
|
||||
if self.use_grad_checkpointing and self.training:
|
||||
mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs)
|
||||
# TODO: use_reentrant=False is recommended
|
||||
mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs, use_reentrant=True)
|
||||
else:
|
||||
mean, std, transition_vector = self.output_net(h_memory, inputs)
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ def test_sequence_mask():
|
|||
|
||||
|
||||
def test_segment():
|
||||
x = T.range(0, 11)
|
||||
x = T.arange(0, 12)
|
||||
x = x.repeat(8, 1).unsqueeze(1)
|
||||
segment_ids = T.randint(0, 7, (8,))
|
||||
|
||||
|
|
Loading…
Reference in New Issue