renaming for melgan generator

erogol 2020-06-19 12:25:03 +02:00
parent cc2104935e
commit 58784ad09c
3 changed files with 8 additions and 10 deletions


@@ -21,7 +21,7 @@ class ResidualStack(nn.Module):
                     nn.Conv1d(channels,
                               channels,
                               kernel_size=kernel_size,
-                              dilation=layer_padding,
+                              dilation=layer_dilation,
                               bias=True)),
                 nn.LeakyReLU(0.2),
                 weight_norm(
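Beyond the renaming, this hunk fixes a bug: the dilation argument was being fed the padding value, so the stacked convolutions never grew their receptive field. A minimal standalone sketch of the usual MelGAN residual-stack convention (block count, channel width, and the exact formulas here are assumptions; the repo's values may differ):

import torch
import torch.nn as nn

kernel_size = 3
blocks = []
for idx in range(3):  # assumed num_res_blocks
    layer_dilation = kernel_size ** idx                      # 1, 3, 9: widens the receptive field
    layer_padding = (kernel_size - 1) // 2 * layer_dilation  # keeps the output length unchanged
    blocks.append(
        nn.Conv1d(4, 4,
                  kernel_size=kernel_size,
                  dilation=layer_dilation,  # the fix: pass the dilation, not the padding
                  padding=layer_padding,
                  bias=True))

x = torch.randn(1, 4, 100)
for conv in blocks:
    x = conv(x)
print(x.shape)  # torch.Size([1, 4, 100]) -- length preserved at every block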


@@ -1,5 +1,3 @@
-"""Pseudo QMF modules."""
-
 import numpy as np
 import torch
 import torch.nn.functional as F


@@ -77,16 +77,16 @@ class MelganGenerator(nn.Module):
         ]
         self.layers = nn.Sequential(*layers)

-    def forward(self, cond_features):
-        return self.layers(cond_features)
+    def forward(self, c):
+        return self.layers(c)

-    def inference(self, cond_features):
-        cond_features = cond_features.to(self.layers[1].weight.device)
-        cond_features = torch.nn.functional.pad(
-            cond_features,
+    def inference(self, c):
+        c = c.to(self.layers[1].weight.device)
+        c = torch.nn.functional.pad(
+            c,
             (self.inference_padding, self.inference_padding),
             'replicate')
-        return self.layers(cond_features)
+        return self.layers(c)

     def remove_weight_norm(self):
         for _, layer in enumerate(self.layers):
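The inference path pads the conditioning features with 'replicate' mode before running the network, repeating the edge frames instead of inserting zeros, which avoids artifacts at the clip boundaries. A small standalone sketch of just that padding step (the tensor shape and padding width are illustrative assumptions; the model stores its own inference_padding):

import torch
import torch.nn.functional as F

inference_padding = 2        # assumed value
c = torch.randn(1, 80, 100)  # (batch, mel channels, frames)

# Repeat the first and last frames instead of zero-padding the edges.
padded = F.pad(c, (inference_padding, inference_padding), 'replicate')
print(padded.shape)          # torch.Size([1, 80, 104])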