From ca3743539a85f753f64405ba77179ee2047297df Mon Sep 17 00:00:00 2001 From: root Date: Wed, 20 Jan 2021 02:12:29 +0000 Subject: [PATCH] load_checkpoint func for vocoder models --- TTS/vocoder/models/melgan_generator.py | 8 ++++++++ .../models/parallel_wavegan_generator.py | 10 ++++++++++ TTS/vocoder/models/wavegrad.py | 19 +++++++++++++++++++ TTS/vocoder/utils/generic_utils.py | 1 + 4 files changed, 38 insertions(+) diff --git a/TTS/vocoder/models/melgan_generator.py b/TTS/vocoder/models/melgan_generator.py index 9ab98cef..e5fd46eb 100644 --- a/TTS/vocoder/models/melgan_generator.py +++ b/TTS/vocoder/models/melgan_generator.py @@ -95,3 +95,11 @@ class MelganGenerator(nn.Module): nn.utils.remove_weight_norm(layer) except ValueError: layer.remove_weight_norm() + + def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument + state = torch.load(checkpoint_path, map_location=torch.device('cpu')) + self.load_state_dict(state['model']) + if eval: + self.eval() + assert not self.training + self.remove_weight_norm() diff --git a/TTS/vocoder/models/parallel_wavegan_generator.py b/TTS/vocoder/models/parallel_wavegan_generator.py index d703f327..f5ed7712 100644 --- a/TTS/vocoder/models/parallel_wavegan_generator.py +++ b/TTS/vocoder/models/parallel_wavegan_generator.py @@ -39,6 +39,7 @@ class ParallelWaveganGenerator(torch.nn.Module): self.upsample_factors = upsample_factors self.upsample_scale = np.prod(upsample_factors) self.inference_padding = inference_padding + self.use_weight_norm = use_weight_norm # check the number of layers and stacks assert num_res_blocks % stacks == 0 @@ -156,3 +157,12 @@ class ParallelWaveganGenerator(torch.nn.Module): def receptive_field_size(self): return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size) + + def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument + state = torch.load(checkpoint_path, map_location=torch.device('cpu')) + self.load_state_dict(state['model']) + if eval: + self.eval() + assert not self.training + if self.use_weight_norm: + self.remove_weight_norm() diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index da491771..bb9d04b8 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -175,3 +175,22 @@ class Wavegrad(nn.Module): self.x_conv = weight_norm(self.x_conv) self.out_conv = weight_norm(self.out_conv) self.y_conv = weight_norm(self.y_conv) + + + def load_checkpoint(self, config, checkpoint_path, eval=False): + state = torch.load(checkpoint_path, map_location=torch.device('cpu')) + self.load_state_dict(state['model']) + if eval: + self.eval() + assert not self.training + if self.use_weight_norm: + self.remove_weight_norm() + betas = np.linspace(config['test_noise_schedule']['min_val'], + config['test_noise_schedule']['max_val'], + config['test_noise_schedule']['num_steps']) + self.compute_noise_level(betas) + else: + betas = np.linspace(config['train_noise_schedule']['min_val'], + config['train_noise_schedule']['max_val'], + config['train_noise_schedule']['num_steps']) + self.compute_noise_level(betas) diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index d6e2e13b..d17c84d4 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -1,4 +1,5 @@ import re +import torch import importlib import numpy as np from matplotlib import pyplot as plt