1) Add ExponentialLR

This commit is contained in:
rishikksh20 2021-02-25 02:31:28 +05:30 committed by Eren Gölge
parent c20a6b1185
commit 1535777f64
2 changed files with 25 additions and 3 deletions

View File

@ -112,11 +112,13 @@
"disc_clip_grad": -1, // Discriminator gradient clipping threshold.
"lr_scheduler_gen": "ExponentialLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
"lr_scheduler_gen_params": {
"gamma": 0.999
"gamma": 0.999,
"last_epoch": -1
},
"lr_scheduler_disc": "ExponentialLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
"lr_scheduler_disc_params": {
"gamma": 0.999
"gamma": 0.999,
"last_epoch": -1
},
"lr_gen": 0.0002, // Initial learning rate. If Noam decay is active, maximum learning rate.
"lr_disc": 0.0002,

View File

@ -1,3 +1,4 @@
import torch
from torch import nn
from TTS.vocoder.layers.hifigan import MRF
@ -7,6 +8,9 @@ class Generator(nn.Module):
def __init__(self, in_channels=80, out_channels=1, base_channels=512, upsample_kernel=[16, 16, 4, 4],
resblock_kernel_sizes=[3, 7, 11], resblock_dilation_sizes=[1, 3, 5]):
super(Generator, self).__init__()
self.inference_padding = 2
self.input = nn.Sequential(
nn.ReflectionPad1d(3),
nn.utils.weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=7))
@ -39,6 +43,14 @@ class Generator(nn.Module):
out = self.output(x2)
return out
def inference(self, c):
c = c.to(self.layers[1].weight.device)
c = torch.nn.functional.pad(
c,
(self.inference_padding, self.inference_padding),
'replicate')
return self.forward(c)
def remove_weight_norm(self):
nn.utils.remove_weight_norm(self.input[1])
nn.utils.remove_weight_norm(self.output[2])
@ -48,4 +60,12 @@ class Generator(nn.Module):
try:
nn.utils.remove_weight_norm(layer)
except:
layer.remove_weight_norm()
layer.remove_weight_norm()
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
if eval:
self.eval()
assert not self.training
self.remove_weight_norm()