mirror of https://github.com/coqui-ai/TTS.git
load_checkpoint func for vocoder models
This commit is contained in:
parent
ea39715305
commit
ca3743539a
|
@ -95,3 +95,11 @@ class MelganGenerator(nn.Module):
|
||||||
nn.utils.remove_weight_norm(layer)
|
nn.utils.remove_weight_norm(layer)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
layer.remove_weight_norm()
|
layer.remove_weight_norm()
|
||||||
|
|
||||||
|
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument
|
||||||
|
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
|
self.load_state_dict(state['model'])
|
||||||
|
if eval:
|
||||||
|
self.eval()
|
||||||
|
assert not self.training
|
||||||
|
self.remove_weight_norm()
|
||||||
|
|
|
@ -39,6 +39,7 @@ class ParallelWaveganGenerator(torch.nn.Module):
|
||||||
self.upsample_factors = upsample_factors
|
self.upsample_factors = upsample_factors
|
||||||
self.upsample_scale = np.prod(upsample_factors)
|
self.upsample_scale = np.prod(upsample_factors)
|
||||||
self.inference_padding = inference_padding
|
self.inference_padding = inference_padding
|
||||||
|
self.use_weight_norm = use_weight_norm
|
||||||
|
|
||||||
# check the number of layers and stacks
|
# check the number of layers and stacks
|
||||||
assert num_res_blocks % stacks == 0
|
assert num_res_blocks % stacks == 0
|
||||||
|
@ -156,3 +157,12 @@ class ParallelWaveganGenerator(torch.nn.Module):
|
||||||
def receptive_field_size(self):
|
def receptive_field_size(self):
|
||||||
return self._get_receptive_field_size(self.layers, self.stacks,
|
return self._get_receptive_field_size(self.layers, self.stacks,
|
||||||
self.kernel_size)
|
self.kernel_size)
|
||||||
|
|
||||||
|
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument
|
||||||
|
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
|
self.load_state_dict(state['model'])
|
||||||
|
if eval:
|
||||||
|
self.eval()
|
||||||
|
assert not self.training
|
||||||
|
if self.use_weight_norm:
|
||||||
|
self.remove_weight_norm()
|
||||||
|
|
|
@ -175,3 +175,22 @@ class Wavegrad(nn.Module):
|
||||||
self.x_conv = weight_norm(self.x_conv)
|
self.x_conv = weight_norm(self.x_conv)
|
||||||
self.out_conv = weight_norm(self.out_conv)
|
self.out_conv = weight_norm(self.out_conv)
|
||||||
self.y_conv = weight_norm(self.y_conv)
|
self.y_conv = weight_norm(self.y_conv)
|
||||||
|
|
||||||
|
|
||||||
|
def load_checkpoint(self, config, checkpoint_path, eval=False):
|
||||||
|
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
||||||
|
self.load_state_dict(state['model'])
|
||||||
|
if eval:
|
||||||
|
self.eval()
|
||||||
|
assert not self.training
|
||||||
|
if self.use_weight_norm:
|
||||||
|
self.remove_weight_norm()
|
||||||
|
betas = np.linspace(config['test_noise_schedule']['min_val'],
|
||||||
|
config['test_noise_schedule']['max_val'],
|
||||||
|
config['test_noise_schedule']['num_steps'])
|
||||||
|
self.compute_noise_level(betas)
|
||||||
|
else:
|
||||||
|
betas = np.linspace(config['train_noise_schedule']['min_val'],
|
||||||
|
config['train_noise_schedule']['max_val'],
|
||||||
|
config['train_noise_schedule']['num_steps'])
|
||||||
|
self.compute_noise_level(betas)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import re
|
import re
|
||||||
|
import torch
|
||||||
import importlib
|
import importlib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
Loading…
Reference in New Issue