refactor(vocoder): remove duplicate function

This commit is contained in:
Enno Hermann 2024-11-22 17:28:30 +01:00
parent 6ecf47312c
commit 0f69d31f70
2 changed files with 10 additions and 16 deletions

View File

@ -12,6 +12,13 @@ from TTS.vocoder.layers.upsample import ConvUpsample
logger = logging.getLogger(__name__)
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
return (kernel_size - 1) * sum(dilations) + 1
class ParallelWaveganGenerator(torch.nn.Module):
"""PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.
It is similar to WaveNet with no causal convolution.
@ -144,16 +151,9 @@ class ParallelWaveganGenerator(torch.nn.Module):
self.apply(_apply_weight_norm)
@staticmethod
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
return (kernel_size - 1) * sum(dilations) + 1
@property
def receptive_field_size(self):
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
return _get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
def load_checkpoint(
self, config, checkpoint_path, eval=False, cache=False

View File

@ -7,6 +7,7 @@ import torch.nn.functional as F
from torch.nn.utils import parametrize
from TTS.vocoder.layers.lvc_block import LVCBlock
from TTS.vocoder.models.parallel_wavegan_generator import _get_receptive_field_size
logger = logging.getLogger(__name__)
@ -133,17 +134,10 @@ class UnivnetGenerator(torch.nn.Module):
self.apply(_apply_weight_norm)
@staticmethod
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
assert layers % stacks == 0
layers_per_cycle = layers // stacks
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
return (kernel_size - 1) * sum(dilations) + 1
@property
def receptive_field_size(self):
"""Return receptive field size."""
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
return _get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
@torch.no_grad()
def inference(self, c):