diff --git a/vocoder/models/melgan_generator.py b/vocoder/models/melgan_generator.py index 1e47816d..e69e6ef3 100644 --- a/vocoder/models/melgan_generator.py +++ b/vocoder/models/melgan_generator.py @@ -87,3 +87,12 @@ class MelganGenerator(nn.Module): (self.inference_padding, self.inference_padding), 'replicate') return self.layers(cond_features) + + def remove_weight_norm(self): + for _, layer in enumerate(self.layers): + if len(layer.state_dict()) != 0: + try: + nn.utils.remove_weight_norm(layer) + except ValueError: + layer.remove_weight_norm() +