mirror of https://github.com/coqui-ai/TTS.git
refactor(vocoder): remove duplicate function
This commit is contained in:
parent
6ecf47312c
commit
0f69d31f70
|
@ -12,6 +12,13 @@ from TTS.vocoder.layers.upsample import ConvUpsample
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
|
||||
assert layers % stacks == 0
|
||||
layers_per_cycle = layers // stacks
|
||||
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
|
||||
return (kernel_size - 1) * sum(dilations) + 1
|
||||
|
||||
|
||||
class ParallelWaveganGenerator(torch.nn.Module):
|
||||
"""PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf.
|
||||
It is similar to WaveNet with no causal convolution.
|
||||
|
@ -144,16 +151,9 @@ class ParallelWaveganGenerator(torch.nn.Module):
|
|||
|
||||
self.apply(_apply_weight_norm)
|
||||
|
||||
@staticmethod
|
||||
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
|
||||
assert layers % stacks == 0
|
||||
layers_per_cycle = layers // stacks
|
||||
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
|
||||
return (kernel_size - 1) * sum(dilations) + 1
|
||||
|
||||
@property
|
||||
def receptive_field_size(self):
|
||||
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
|
||||
return _get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False, cache=False
|
||||
|
|
|
@ -7,6 +7,7 @@ import torch.nn.functional as F
|
|||
from torch.nn.utils import parametrize
|
||||
|
||||
from TTS.vocoder.layers.lvc_block import LVCBlock
|
||||
from TTS.vocoder.models.parallel_wavegan_generator import _get_receptive_field_size
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -133,17 +134,10 @@ class UnivnetGenerator(torch.nn.Module):
|
|||
|
||||
self.apply(_apply_weight_norm)
|
||||
|
||||
@staticmethod
|
||||
def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x):
|
||||
assert layers % stacks == 0
|
||||
layers_per_cycle = layers // stacks
|
||||
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
|
||||
return (kernel_size - 1) * sum(dilations) + 1
|
||||
|
||||
@property
|
||||
def receptive_field_size(self):
|
||||
"""Return receptive field size."""
|
||||
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
|
||||
return _get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, c):
|
||||
|
|
Loading…
Reference in New Issue