fix multiband melgan inference

This commit is contained in:
erogol 2020-06-09 23:02:55 +02:00
parent c866f23af6
commit 1dfafe003d
1 changed files with 2 additions and 1 deletions

View File

@ -29,10 +29,11 @@ class MultibandMelganGenerator(MelganGenerator):
def pqmf_synthesis(self, x):
return self.pqmf_layer.synthesis(x)
@torch.no_grad()
def inference(self, cond_features):
cond_features = cond_features.to(self.layers[1].weight.device)
cond_features = torch.nn.functional.pad(
cond_features,
(self.inference_padding, self.inference_padding),
'replicate')
return self.pqmf.synthesis(self.layers(cond_features))
return self.pqmf_synthesis(self.layers(cond_features))