mirror of https://github.com/coqui-ai/TTS.git
1) Add ExponentialLR
This commit is contained in:
parent c20a6b1185
commit 1535777f64
@@ -112,11 +112,13 @@
     "disc_clip_grad": -1, // Discriminator gradient clipping threshold.
     "lr_scheduler_gen": "ExponentialLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
     "lr_scheduler_gen_params": {
-        "gamma": 0.999
+        "gamma": 0.999,
+        "last_epoch": -1
     },
     "lr_scheduler_disc": "ExponentialLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
     "lr_scheduler_disc_params": {
-        "gamma": 0.999
+        "gamma": 0.999,
+        "last_epoch": -1
     },
     "lr_gen": 0.0002, // Initial learning rate. If Noam decay is active, maximum learning rate.
     "lr_disc": 0.0002,
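For reference, a minimal sketch of how scheduler keys like "lr_scheduler_gen": "ExponentialLR" and their params map onto PyTorch's scheduler API; the model and optimizer below are placeholders, not the actual TTS trainer objects.

# Minimal sketch, assuming the config params are passed straight to
# torch.optim.lr_scheduler.ExponentialLR; model/optimizer are placeholders.
import torch

model = torch.nn.Linear(10, 1)                               # stand-in for the generator
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)  # "lr_gen"
scheduler = torch.optim.lr_scheduler.ExponentialLR(
    optimizer, gamma=0.999, last_epoch=-1)                   # "lr_scheduler_gen_params"

for _ in range(3):
    optimizer.step()                 # a real training step would go here
    scheduler.step()                 # multiplies the lr by gamma once per call
    print(scheduler.get_last_lr())   # [0.0002 * 0.999 ** n]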
@@ -1,3 +1,4 @@
+import torch
 from torch import nn
 from TTS.vocoder.layers.hifigan import MRF
 
@@ -7,6 +8,9 @@ class Generator(nn.Module):
     def __init__(self, in_channels=80, out_channels=1, base_channels=512, upsample_kernel=[16, 16, 4, 4],
                  resblock_kernel_sizes=[3, 7, 11], resblock_dilation_sizes=[1, 3, 5]):
         super(Generator, self).__init__()
+
+        self.inference_padding = 2
+
         self.input = nn.Sequential(
             nn.ReflectionPad1d(3),
             nn.utils.weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=7))
@@ -39,6 +43,14 @@ class Generator(nn.Module):
         out = self.output(x2)
         return out
 
+    def inference(self, c):
+        c = c.to(self.layers[1].weight.device)
+        c = torch.nn.functional.pad(
+            c,
+            (self.inference_padding, self.inference_padding),
+            'replicate')
+        return self.forward(c)
+
     def remove_weight_norm(self):
         nn.utils.remove_weight_norm(self.input[1])
         nn.utils.remove_weight_norm(self.output[2])
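A small sketch of what the padding step in the new inference() method does, assuming the conditioning input is a [batch, mel_channels, frames] spectrogram; the tensor and sizes below are illustrative only.

# Sketch only: replicate-padding of a fake mel spectrogram, as done in inference().
import torch

c = torch.randn(1, 80, 50)          # assumed [batch, mel_channels, frames] input
inference_padding = 2
padded = torch.nn.functional.pad(
    c, (inference_padding, inference_padding), 'replicate')
print(padded.shape)                 # torch.Size([1, 80, 54]); edge frames repeated on both sides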
@@ -48,4 +60,12 @@ class Generator(nn.Module):
         try:
             nn.utils.remove_weight_norm(layer)
         except:
-            layer.remove_weight_norm()
+            layer.remove_weight_norm()
+
+    def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
+        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
+        self.load_state_dict(state['model'])
+        if eval:
+            self.eval()
+            assert not self.training
+            self.remove_weight_norm()
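Hypothetical end-to-end use of the two new methods; the import path and checkpoint filename are assumptions for illustration, not part of this commit.

# Hypothetical usage sketch; the module path and checkpoint file are assumed.
import torch
from TTS.vocoder.models.hifigan_generator import Generator   # assumed module path

model = Generator()
model.load_checkpoint(config=None, checkpoint_path='generator.pth', eval=True)  # placeholder path
mel = torch.randn(1, 80, 100)       # fake conditioning spectrogram
with torch.no_grad():
    wav = model.inference(mel)      # pads the input and calls forward()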