diff --git a/TTS/vocoder/configs/modified_hifigan.json b/TTS/vocoder/configs/modified_hifigan.json
index b3d6d22f..feeb542e 100644
--- a/TTS/vocoder/configs/modified_hifigan.json
+++ b/TTS/vocoder/configs/modified_hifigan.json
@@ -112,11 +112,13 @@
     "disc_clip_grad": -1, // Discriminator gradient clipping threshold.
     "lr_scheduler_gen": "ExponentialLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
     "lr_scheduler_gen_params": {
-        "gamma": 0.999
+        "gamma": 0.999,
+        "last_epoch": -1
     },
     "lr_scheduler_disc": "ExponentialLR", // one of the schedulers from https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
     "lr_scheduler_disc_params": {
-        "gamma": 0.999
+        "gamma": 0.999,
+        "last_epoch": -1
     },
     "lr_gen": 0.0002, // Initial learning rate. If Noam decay is active, maximum learning rate.
     "lr_disc": 0.0002,
diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py
index 744f8e86..e15652ec 100644
--- a/TTS/vocoder/models/hifigan_generator.py
+++ b/TTS/vocoder/models/hifigan_generator.py
@@ -1,3 +1,4 @@
+import torch
 from torch import nn
 
 from TTS.vocoder.layers.hifigan import MRF
@@ -7,6 +8,9 @@ class Generator(nn.Module):
     def __init__(self, in_channels=80, out_channels=1, base_channels=512, upsample_kernel=[16, 16, 4, 4],
                  resblock_kernel_sizes=[3, 7, 11], resblock_dilation_sizes=[1, 3, 5]):
         super(Generator, self).__init__()
+
+        self.inference_padding = 2
+
         self.input = nn.Sequential(
             nn.ReflectionPad1d(3),
             nn.utils.weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=7))
@@ -39,6 +43,14 @@ class Generator(nn.Module):
         out = self.output(x2)
         return out
 
+    def inference(self, c):
+        c = c.to(self.layers[1].weight.device)
+        c = torch.nn.functional.pad(
+            c,
+            (self.inference_padding, self.inference_padding),
+            'replicate')
+        return self.forward(c)
+
     def remove_weight_norm(self):
         nn.utils.remove_weight_norm(self.input[1])
         nn.utils.remove_weight_norm(self.output[2])
@@ -48,4 +60,12 @@ class Generator(nn.Module):
             try:
                 nn.utils.remove_weight_norm(layer)
             except:
-                layer.remove_weight_norm()
\ No newline at end of file
+                layer.remove_weight_norm()
+
+    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
+        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
+        self.load_state_dict(state['model'])
+        if eval:
+            self.eval()
+            assert not self.training
+            self.remove_weight_norm()
\ No newline at end of file
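
A minimal usage sketch of the new `load_checkpoint` and `inference` entry points added by this diff, assuming a trained generator checkpoint; the checkpoint filename and mel-spectrogram tensor below are placeholders, and `None` is passed for `config` since `load_checkpoint` does not use it:

```python
import torch

from TTS.vocoder.models.hifigan_generator import Generator

# Build the generator with the default channel settings (in_channels matches
# the 80-band mel features used by modified_hifigan.json).
model = Generator(in_channels=80)

# load_checkpoint ignores the config argument; eval=True switches the module
# to eval mode and strips weight norm, as done in the method above.
model.load_checkpoint(None, "hifigan_checkpoint.pth.tar", eval=True)

# Placeholder mel-spectrogram: [batch, mel_channels, frames].
mel = torch.randn(1, 80, 200)

with torch.no_grad():
    # inference() replicate-pads the input by inference_padding frames on
    # each side before calling forward().
    waveform = model.inference(mel)

print(waveform.shape)  # [1, out_channels, samples]
```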